Complex NN Integration Path

Use this path when the default BaseNNModel._train_model() is no longer enough.

USGAN is the clearest example.

When This Path Is Correct

Move to the complex path if you need one or more of these:

  • Multiple optimizers (e.g. generator + discriminator)

  • Alternating update schedules (e.g. train G for k steps, then D for 1 step)

  • Different training branches for different submodules

  • Explicit pretraining before the main training loop

  • Custom checkpoint-selection logic that still follows task semantics

Examples in PyPOTS:

  • USGAN — generator/discriminator alternation

  • CRLI — multi-optimizer clustering training

  • VaDER — pretraining before the main training phase

What Stays The Same

Even when you override _train_model(), these contracts should stay stable:

  • The wrapper still exposes the same public fit() and predict() API

  • The model still returns the correct inference-time task result

  • Best checkpoint tracking still exists

  • Early stopping still uses a meaningful metric

  • The final trained wrapper still loads the best checkpoint before inference

If those invariants disappear, review becomes much harder.

USGAN as the Reference Pattern

USGAN still inherits BaseNNImputer, but it cannot use the default training loop.

Why?

  • It has G_optimizer and D_optimizer

  • It alternates generator and discriminator updates

  • One batch may execute multiple optimizer steps

So USGAN overrides _train_model() in the wrapper.

What USGAN._train_model() Preserves

The useful lesson is not “copy this loop.” The useful lesson is what it preserves while changing the orchestration:

  • Resets best-loss state at the start

  • Runs explicit train and validation phases

  • Tracks the metric used for best-model selection

  • Updates patience for early stopping

  • Saves better checkpoints when configured

  • Leaves inference-time task behavior unchanged

That is the contract to preserve.

Implementation Guide

Step 1: Start From the Standard Path

Begin by setting up your model as if it were a standard NN model (see Standard NN Integration Path). Then identify which part of the training loop needs to change.

Typical reasons to override _train_model():

  • You have two or more optimizers that need separate step() calls

  • Different sub-models need different forward passes per batch

  • You need a pretraining phase before the main loop

Step 2: Implement the Custom Training Loop

Override _train_model() in your model.py. Here is a skeleton based on USGAN:

# In your model.py

class YourGANModel(BaseNNImputer):
    def __init__(self, ...):
        super().__init__(...)

        # Set up dual optimizers
        self.model = _YourGANCore(...)
        self._print_model_size()
        self._send_model_to_given_device()

        # Generator optimizer
        self.G_optimizer = Adam(lr=0.001)
        self.G_optimizer.init_optimizer(self.model.generator.parameters())

        # Discriminator optimizer
        self.D_optimizer = Adam(lr=0.001)
        self.D_optimizer.init_optimizer(self.model.discriminator.parameters())

    def _train_model(self, train_dataloader, val_dataloader=None):
        # Reset tracking state
        self.best_loss = float("inf")
        self.best_epoch = 0
        patience_counter = 0

        for epoch in range(self.epochs):
            self.model.train()
            epoch_train_loss_collector = []

            for idx, data in enumerate(train_dataloader):
                inputs = self._assemble_input_for_training(data)

                # --- Discriminator update ---
                for _ in range(self.D_steps):
                    self.D_optimizer.zero_grad()
                    results = self.model(inputs, training_object="discriminator")
                    results["D_loss"].backward()
                    self.D_optimizer.step()

                # --- Generator update ---
                for _ in range(self.G_steps):
                    self.G_optimizer.zero_grad()
                    results = self.model(inputs, training_object="generator")
                    results["G_loss"].backward()
                    self.G_optimizer.step()

                epoch_train_loss_collector.append(
                    results["G_loss"].item()
                )

            # --- Validation phase ---
            if val_dataloader is not None:
                self.model.eval()
                val_loss_collector = []
                with torch.no_grad():
                    for idx, data in enumerate(val_dataloader):
                        inputs = self._assemble_input_for_validating(data)
                        results = self.model(inputs, calc_criterion=True)
                        val_loss_collector.append(
                            results["metric"].item()
                        )

                mean_val_loss = np.mean(val_loss_collector)

                # --- Best model tracking ---
                if mean_val_loss < self.best_loss:
                    self.best_loss = mean_val_loss
                    self.best_epoch = epoch
                    self.best_model_dict = deepcopy(self.model.state_dict())
                    patience_counter = 0
                else:
                    patience_counter += 1

                # --- Early stopping ---
                if (self.patience is not None
                    and patience_counter >= self.patience):
                    break

            # --- TensorBoard logging ---
            if self.summary_writer is not None:
                mean_train_loss = np.mean(epoch_train_loss_collector)
                self._save_log_into_tb_file(
                    epoch, "training", {"loss": mean_train_loss}
                )
                if val_dataloader is not None:
                    self._save_log_into_tb_file(
                        epoch, "validating", {"loss": mean_val_loss}
                    )

Step 3: Handle the Core Differently

For a GAN-like model, your core.py forward pass may need a training_object parameter to distinguish between generator and discriminator forward passes:

class _YourGANCore(nn.Module):
    def __init__(self, ...):
        super().__init__()
        self.generator = ...
        self.discriminator = ...

    def forward(self, inputs, training_object="generator",
                calc_criterion=False):
        X, missing_mask = inputs["X"], inputs["missing_mask"]

        # Generator forward pass
        imputed_data = self.generator(X, missing_mask)
        imputed_data = missing_mask * X + (1 - missing_mask) * imputed_data

        results = {"imputation": imputed_data}

        if training_object == "discriminator":
            # Discriminator-specific loss
            d_prob = self.discriminator(imputed_data.detach(), missing_mask)
            results["D_loss"] = self._d_loss(d_prob, missing_mask)

        elif training_object == "generator":
            # Generator-specific loss (including adversarial + reconstruction)
            d_prob = self.discriminator(imputed_data, missing_mask)
            results["G_loss"] = self._g_loss(
                d_prob, imputed_data, X, missing_mask
            )

        if calc_criterion:
            # For validation metric
            X_ori = inputs["X_ori"]
            indicating_mask = inputs["indicating_mask"]
            if not self.training:
                results["metric"] = self.validation_metric(
                    imputed_data, X_ori, indicating_mask
                )

        return results

What You Are Allowed to Customize

Override _train_model() only for orchestration-level reasons such as:

  • Custom optimizer order

  • Custom gradient flow

  • Pretraining stages

  • Multi-branch loss collection

  • Special logging needs tied to the custom schedule

Do not override it just to move ordinary data assembly or forward logic around.

High-Risk Mistakes

The complex path usually fails in four places:

  1. Best checkpoint selected from the wrong signal — e.g. using generator loss instead of a proper validation metric

  2. Patience updated inconsistently — e.g. forgetting to update patience in some code paths

  3. Training and validation use different input key assumptions — e.g. validation assembly expects keys that aren’t provided

  4. Inference still works, but the task result key changes silently — e.g. returning "generated" instead of "imputation"

Quick self-check: “Did I only change training orchestration, or did I accidentally change the public model contract too?”

Definition of Done

Your complex NN integration is done when:

  • Each optimizer branch is exercised during training

  • Validation still drives model selection in a clear way

  • The best checkpoint is restored after training

  • predict() still returns the expected task result

  • Targeted tests prove both training and inference paths

If only one optimizer and one ordinary loss remain, the model probably belongs back on the Standard NN Integration Path.