Complex NN Integration Path
Use this path when the default BaseNNModel._train_model() is no longer enough.
USGAN is the clearest example.
When This Path Is Correct
Move to the complex path if you need one or more of these:
Multiple optimizers (e.g. generator + discriminator)
Alternating update schedules (e.g. train G for k steps, then D for 1 step)
Different training branches for different submodules
Explicit pretraining before the main training loop
Custom checkpoint-selection logic that still follows task semantics
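An alternating update schedule is really just a fixed order of optimizer steps per batch. The helper below is hypothetical and not a PyPOTS API. It spells out the "train G for k steps, then D for 1 step" pattern as plain data, which makes the schedule easy to unit-test before any model code exists:

```python
from typing import List


def batch_update_schedule(g_steps: int, d_steps: int, d_first: bool = True) -> List[str]:
    """Return the order of optimizer steps executed for one batch.

    Hypothetical helper: "G" = one generator optimizer step,
    "D" = one discriminator optimizer step.
    """
    if g_steps < 0 or d_steps < 0:
        raise ValueError("step counts must be non-negative")
    d_block = ["D"] * d_steps
    g_block = ["G"] * g_steps
    return d_block + g_block if d_first else g_block + d_block


# "train G for k steps, then D for 1 step" with k = 3
print(batch_update_schedule(g_steps=3, d_steps=1, d_first=False))  # ['G', 'G', 'G', 'D']
```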
Examples in PyPOTS:
USGAN: generator/discriminator alternation
CRLI: multi-optimizer clustering training
VaDER: pretraining before the main training phase
What Stays The Same
Even when you override _train_model(), these contracts should stay stable:
The wrapper still exposes the same public fit() and predict() API
The model still returns the correct inference-time task result
Best checkpoint tracking still exists
Early stopping still uses a meaningful metric
The final trained wrapper still loads the best checkpoint before inference
If those invariants disappear, review becomes much harder.
USGAN as the Reference Pattern
USGAN still inherits from BaseNNImputer, but it cannot use the default training loop.
Why?
It has two optimizers, G_optimizer and D_optimizer
It alternates generator and discriminator updates
One batch may execute multiple optimizer steps
So USGAN overrides _train_model() in the wrapper.
What USGAN._train_model() Preserves
The useful lesson is not “copy this loop.” The useful lesson is what it preserves while changing the orchestration:
Resets best-loss state at the start
Runs explicit train and validation phases
Tracks the metric used for best-model selection
Updates patience for early stopping
Saves better checkpoints when configured
Leaves inference-time task behavior unchanged
That is the contract to preserve.
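The bookkeeping in that list is small enough to isolate. The following is a hypothetical tracker, not a PyPOTS class. It covers reset, best-metric selection, patience updates on every non-improving epoch, and a snapshot of the best state. Passing patience=None or a non-positive value disables early stopping:

```python
import math
from copy import deepcopy
from typing import Any, Dict, Optional


class BestCheckpointTracker:
    """Hypothetical helper bundling the state a custom _train_model() must keep."""

    def __init__(self, patience: Optional[int] = None):
        self.patience = patience  # None or <= 0 disables early stopping
        self.reset()

    def reset(self) -> None:
        # Every training run starts from scratch
        self.best_loss = math.inf
        self.best_epoch = -1
        self.best_state: Optional[Dict[str, Any]] = None
        self.bad_epochs = 0

    def update(self, epoch: int, metric: float, state: Dict[str, Any]) -> bool:
        """Record one epoch's selection metric; return True if it is a new best."""
        if math.isnan(metric):
            raise ValueError(f"selection metric is NaN at epoch {epoch}")
        if metric < self.best_loss:
            self.best_loss, self.best_epoch = metric, epoch
            self.best_state = deepcopy(state)  # snapshot, not a live reference
            self.bad_epochs = 0
            return True
        self.bad_epochs += 1  # patience advances on every non-improving epoch
        return False

    def should_stop(self) -> bool:
        return (self.patience is not None and self.patience > 0
                and self.bad_epochs >= self.patience)


tracker = BestCheckpointTracker(patience=2)
for epoch, val_metric in enumerate([0.9, 0.7, 0.8, 0.75]):
    tracker.update(epoch, val_metric, {"epoch": epoch})
    if tracker.should_stop():
        break
print(tracker.best_epoch, tracker.best_loss, epoch)  # 1 0.7 3
```

The deepcopy matters with real models because a PyTorch state_dict() holds references to the live parameter tensors, so later optimizer steps would otherwise overwrite the "best" weights.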
Implementation Guide
Step 1: Start From the Standard Path
Begin by setting up your model as if it were a standard NN model (see Standard NN Integration Path). Then identify which part of the training loop needs to change.
Typical reasons to override _train_model():
You have two or more optimizers that need separate step() calls
Different sub-models need different forward passes per batch
You need a pretraining phase before the main loop
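For the pretraining case, the orchestration can stay simple: run one loop for the pretraining phase, then run the main loop. The sketch below assumes hypothetical per-epoch callables, pretrain_step and main_step, each of which returns a loss. Every log entry is tagged with its stage, so pretraining losses can be excluded from best-checkpoint comparisons:

```python
from typing import Callable, List, Tuple


def run_two_phase_training(
    pretrain_epochs: int,
    epochs: int,
    pretrain_step: Callable[[int], float],
    main_step: Callable[[int], float],
) -> List[Tuple[str, int, float]]:
    """Hypothetical orchestration: an explicit pretraining phase, then the main loop.

    Returns (stage, epoch, loss) entries; the stage tag lets callers keep
    pretraining losses out of best-checkpoint selection.
    """
    log = []
    for epoch in range(pretrain_epochs):
        log.append(("pretraining", epoch, pretrain_step(epoch)))
    for epoch in range(epochs):
        log.append(("training", epoch, main_step(epoch)))
    return log


log = run_two_phase_training(2, 3, lambda e: 1.0 / (e + 1), lambda e: 0.5 - 0.1 * e)
main_phase = [entry for entry in log if entry[0] == "training"]
print(len(log), len(main_phase))  # 5 3
```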
Step 2: Implement the Custom Training Loop
Override _train_model() in your model.py. The skeleton below is a simplified sketch modeled on USGAN, not USGAN's exact code:
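```python
# In your model.py
from copy import deepcopy

import numpy as np
import torch

from pypots.imputation.base import BaseNNImputer
from pypots.optim import Adam

from .core import _YourGANCore  # hypothetical module layout


class YourGANModel(BaseNNImputer):
    def __init__(self, G_steps=1, D_steps=1, **base_kwargs):
        # base_kwargs: the usual BaseNNImputer arguments (batch_size, epochs,
        # patience, device, saving_path, ...); check the signature of your PyPOTS version
        super().__init__(**base_kwargs)
        self.G_steps = G_steps  # generator updates per batch
        self.D_steps = D_steps  # discriminator updates per batch

        self.model = _YourGANCore(...)  # pass your model hyperparameters here
        self._print_model_size()
        self._send_model_to_given_device()

        # Dual optimizers, each bound to its own sub-module's parameters
        self.G_optimizer = Adam(lr=0.001)
        self.G_optimizer.init_optimizer(self.model.generator.parameters())
        self.D_optimizer = Adam(lr=0.001)
        self.D_optimizer.init_optimizer(self.model.discriminator.parameters())

    def _train_model(self, train_dataloader, val_dataloader=None):
        # Each call starts from scratch: reset best-loss and checkpoint state
        self.best_loss = float("inf")
        self.best_epoch = -1
        self.best_model_dict = None
        patience_counter = 0

        for epoch in range(self.epochs):
            # --- Training phase ---
            self.model.train()
            epoch_train_loss_collector = []
            for data in train_dataloader:
                inputs = self._assemble_input_for_training(data)

                # Discriminator update (the core detaches the generator output)
                for _ in range(self.D_steps):
                    self.D_optimizer.zero_grad()
                    results = self.model(inputs, training_object="discriminator")
                    results["D_loss"].backward()
                    self.D_optimizer.step()

                # Generator update
                for _ in range(self.G_steps):
                    self.G_optimizer.zero_grad()
                    results = self.model(inputs, training_object="generator")
                    results["G_loss"].backward()
                    self.G_optimizer.step()
                    epoch_train_loss_collector.append(results["G_loss"].item())

            mean_train_loss = np.mean(epoch_train_loss_collector)

            # --- Validation phase ---
            if val_dataloader is not None:
                self.model.eval()
                val_loss_collector = []
                with torch.no_grad():
                    for data in val_dataloader:
                        inputs = self._assemble_input_for_validating(data)
                        # training_object=None: skip the adversarial losses, metric only
                        results = self.model(
                            inputs, training_object=None, calc_criterion=True
                        )
                        val_loss_collector.append(results["metric"].item())
                mean_val_loss = np.mean(val_loss_collector)
                selection_loss = mean_val_loss
            else:
                # No validation set: fall back to the training (generator) loss
                selection_loss = mean_train_loss

            # --- TensorBoard logging (before a possible early stop) ---
            if self.summary_writer is not None:
                self._save_log_into_tb_file(
                    epoch, "training", {"loss": mean_train_loss}
                )
                if val_dataloader is not None:
                    self._save_log_into_tb_file(
                        epoch, "validating", {"loss": mean_val_loss}
                    )

            # --- Best model tracking ---
            if selection_loss < self.best_loss:
                self.best_loss = selection_loss
                self.best_epoch = epoch
                # state_dict() returns references to live tensors, so snapshot it
                self.best_model_dict = deepcopy(self.model.state_dict())
                patience_counter = 0
            else:
                patience_counter += 1

            # --- Early stopping (None or a non-positive patience disables it) ---
            if (self.patience is not None and self.patience > 0
                    and patience_counter >= self.patience):
                break

        # self.best_model_dict must be loaded back into self.model before
        # inference (e.g. in fit(), right after this method returns)
```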
Step 3: Handle the Core Differently
For a GAN-like model, your core.py forward pass may need a training_object parameter
to distinguish between generator and discriminator forward passes:
```python
# In your core.py
import torch.nn as nn


class _YourGANCore(nn.Module):
    # _d_loss, _g_loss and validation_metric are yours to define
    def __init__(self, n_steps, n_features):  # example hyperparameters
        super().__init__()
        self.generator = ...      # placeholder: build your generator (an nn.Module)
        self.discriminator = ...  # placeholder: build your discriminator (an nn.Module)

    def forward(self, inputs, training_object="generator", calc_criterion=False):
        X, missing_mask = inputs["X"], inputs["missing_mask"]

        # Generator forward pass; keep observed values, fill only missing ones
        imputed_data = self.generator(X, missing_mask)
        imputed_data = missing_mask * X + (1 - missing_mask) * imputed_data
        results = {"imputation": imputed_data}

        if training_object == "discriminator":
            # Detached input keeps D-loss gradients out of the generator
            d_prob = self.discriminator(imputed_data.detach(), missing_mask)
            results["D_loss"] = self._d_loss(d_prob, missing_mask)
        elif training_object == "generator":
            # Adversarial + reconstruction loss, flowing back into the generator
            d_prob = self.discriminator(imputed_data, missing_mask)
            results["G_loss"] = self._g_loss(d_prob, imputed_data, X, missing_mask)
        # training_object=None (validation/inference): no training loss

        if calc_criterion and not self.training:
            # Validation metric on the artificially masked entries only
            results["metric"] = self.validation_metric(
                imputed_data, inputs["X_ori"], inputs["indicating_mask"]
            )

        return results
```
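The core returns no training loss for an unrecognized training_object, because a typo such as "discriminater" falls through both branches. The wrapper can therefore look up losses through a small guard instead of indexing results directly. This is a hypothetical helper and not part of PyPOTS:

```python
from typing import Any, Dict, Mapping

# Hypothetical mapping from training_object to the loss key that branch must produce
BRANCH_LOSS_KEYS: Dict[str, str] = {"generator": "G_loss", "discriminator": "D_loss"}


def loss_for_branch(results: Mapping[str, Any], training_object: str) -> Any:
    """Return the loss for one branch, failing loudly on typos or missing keys."""
    if training_object not in BRANCH_LOSS_KEYS:
        raise ValueError(
            f"unknown training_object {training_object!r}; "
            f"expected one of {sorted(BRANCH_LOSS_KEYS)}"
        )
    key = BRANCH_LOSS_KEYS[training_object]
    if key not in results:
        raise KeyError(f"core returned no {key!r} for training_object={training_object!r}")
    return results[key]


print(loss_for_branch({"imputation": [1.0], "G_loss": 0.3}, "generator"))  # 0.3
```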
What You Are Allowed to Customize
Override _train_model() only for orchestration-level reasons such as:
Custom optimizer order
Custom gradient flow
Pretraining stages
Multi-branch loss collection
Special logging needs tied to the custom schedule
Do not override it just to move ordinary data assembly or forward logic around.
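Multi-branch loss collection usually means keeping one loss series per optimizer instead of a single epoch loss. The stdlib-only sketch below uses a hypothetical helper that averages per-step losses by branch. The resulting dict could then be logged per branch, for example as the loss dictionary passed to _save_log_into_tb_file() in the skeleton:

```python
from collections import defaultdict
from statistics import fmean
from typing import Dict, Iterable, Tuple


def summarize_branch_losses(records: Iterable[Tuple[str, float]]) -> Dict[str, float]:
    """Average per-step losses separately for each optimizer branch.

    Hypothetical helper: records are (branch_name, loss) pairs collected
    inside the custom loop, e.g. ("G", 0.41) after a generator step.
    """
    per_branch = defaultdict(list)
    for branch, loss in records:
        per_branch[branch].append(loss)
    return {branch: fmean(losses) for branch, losses in per_branch.items()}


steps = [("D", 0.5), ("G", 1.0), ("D", 0.25), ("G", 2.0)]
print(summarize_branch_losses(steps))  # {'D': 0.375, 'G': 1.5}
```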
High-Risk Mistakes
The complex path usually fails in four places:
Best checkpoint selected from the wrong signal (e.g. using the generator loss instead of a proper validation metric)
Patience updated inconsistently (e.g. forgetting to update patience in some code paths)
Training and validation use different input key assumptions (e.g. validation assembly expects keys that aren't provided)
Inference still works, but the task result key changes silently (e.g. returning "generated" instead of "imputation")
Quick self-check: “Did I only change training orchestration, or did I accidentally change the public model contract too?”
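The first mistake is easiest to avoid when the selection signal is computed in exactly one place. The sketch below shows a hypothetical policy, not PyPOTS code. It uses the validation metric whenever validation data exists and falls back to the training loss only when it does not. It also rejects NaN explicitly, because NaN < best is always False and would otherwise silently stop the model from ever improving:

```python
import math
from statistics import fmean
from typing import Optional, Sequence, Tuple


def selection_signal(
    val_metrics: Optional[Sequence[float]],
    train_losses: Sequence[float],
) -> Tuple[str, float]:
    """Pick the number that drives best-checkpoint selection for one epoch."""
    if val_metrics:
        source, values = "validation", val_metrics
    else:
        source, values = "training", train_losses
    if not values:
        raise ValueError("no losses were collected this epoch")
    value = fmean(values)
    if math.isnan(value):
        raise ValueError(f"{source} signal is NaN")
    return source, value


print(selection_signal([0.25, 0.75], [1.3, 1.1]))  # ('validation', 0.5)
print(selection_signal(None, [1.0, 2.0]))          # ('training', 1.5)
```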
Definition of Done
Your complex NN integration is done when:
Each optimizer branch is exercised during training
Validation still drives model selection in a clear way
The best checkpoint is restored after training
predict() still returns the expected task result
Targeted tests prove both training and inference paths
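To show that each optimizer branch is exercised, a test can wrap the real optimizers in counters. This stdlib-only sketch uses a hypothetical CountingOptimizer test double. It replays the skeleton's per-batch schedule (D_steps discriminator steps, then G_steps generator steps) to illustrate the shape of the assertion. A real test would count steps on the model's actual G_optimizer and D_optimizer instead:

```python
from typing import Dict


class CountingOptimizer:
    """Hypothetical test double standing in for a real optimizer."""

    def __init__(self, name: str):
        self.name = name
        self.steps = 0

    def zero_grad(self) -> None:
        pass

    def step(self) -> None:
        self.steps += 1


def run_fake_epoch(n_batches: int, g_steps: int, d_steps: int) -> Dict[str, int]:
    """Replay the skeleton's D-then-G schedule and report steps per optimizer."""
    optimizers = {"G": CountingOptimizer("G"), "D": CountingOptimizer("D")}
    for _ in range(n_batches):
        for _ in range(d_steps):
            optimizers["D"].zero_grad()
            optimizers["D"].step()
        for _ in range(g_steps):
            optimizers["G"].zero_grad()
            optimizers["G"].step()
    return {name: opt.steps for name, opt in optimizers.items()}


counts = run_fake_epoch(n_batches=4, g_steps=1, d_steps=2)
assert all(n > 0 for n in counts.values()), f"an optimizer branch never stepped: {counts}"
print(counts)  # {'G': 4, 'D': 8}
```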
If only one optimizer and one ordinary loss remain, the model probably belongs back on the Standard NN Integration Path.