Standard NN Integration Path
Use this path when one optimizer and the default BaseNNModel._train_model() are enough.
SAITS is the best reference model for this path.
When This Path Is Correct
Choose the standard NN path when:
One optimizer is enough
One main objective drives training
You can express training and validation with results["loss"] and results["metric"]
You do not need alternating update schedules
If any of these are not true, switch early to Complex NN Integration Path.
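To make these conditions concrete, the following is a minimal sketch of what a single-optimizer loop looks like. It is not the actual BaseNNModel._train_model() implementation; `ToyCore`, `run_epoch`, `validate`, and the random data are hypothetical names introduced only for illustration, and plain PyTorch is assumed.

```python
import torch
import torch.nn as nn


class ToyCore(nn.Module):
    """Hypothetical stand-in for a model core; not a PyPOTS class."""

    def __init__(self, n_features: int):
        super().__init__()
        self.proj = nn.Linear(n_features, n_features)

    def forward(self, inputs: dict, calc_criterion: bool = False) -> dict:
        reconstruction = self.proj(inputs["X"])
        results = {"reconstruction": reconstruction}
        if calc_criterion:
            error = (reconstruction - inputs["X_ori"]).abs().mean()
            # One scalar per stage: "loss" in train mode, "metric" in eval mode.
            results["loss" if self.training else "metric"] = error
        return results


def run_epoch(core: nn.Module, optimizer: torch.optim.Optimizer, batches: list) -> float:
    """One training epoch: a single optimizer stepping on a single objective."""
    core.train()
    total = 0.0
    for inputs in batches:
        optimizer.zero_grad()
        results = core(inputs, calc_criterion=True)
        results["loss"].backward()
        optimizer.step()
        total += results["loss"].item()
    return total / len(batches)


@torch.no_grad()
def validate(core: nn.Module, batches: list) -> float:
    """Validation pass: reads only results["metric"], never updates weights."""
    core.eval()
    metrics = [core(inputs, calc_criterion=True)["metric"].item() for inputs in batches]
    return sum(metrics) / len(metrics)


torch.manual_seed(0)
core = ToyCore(n_features=3)
optimizer = torch.optim.Adam(core.parameters(), lr=1e-2)
X = torch.randn(8, 5, 3)  # [batch, n_steps, n_features]
batches = [{"X": X, "X_ori": X}]
train_loss = run_epoch(core, optimizer, batches)
val_metric = validate(core, batches)
```

If your model would need a second `optimizer.step()` on different parameters, or a loss that cannot be reduced to one scalar per batch, a loop of this shape no longer fits.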
Start From the Task Template
PyPOTS ships task templates to help you get started:
pypots/imputation/template/
pypots/forecasting/template/
pypots/classification/template/
pypots/clustering/template/
Use them as scaffolding, not as the final spec. The real contract comes from the task base class. Do not copy placeholder output names blindly.
Step-by-Step Implementation Guide
Step 1: Pick the Task Contract
Before writing code, decide:
Which task base class you inherit (e.g. BaseNNImputer)
Which public helper method must work (e.g. impute(), forecast(), classify())
Which public result key must exist (e.g. "imputation" for imputation models)
For example, an imputation model must end up with "imputation" in the dict returned by predict().
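One way to pin the contract down before writing the model is a small check you can reuse in tests. The sketch below assumes NumPy; `check_result_contract` is a hypothetical helper, not part of PyPOTS, and the sample shape is made up for illustration.

```python
import numpy as np


def check_result_contract(results: dict, key: str, expected_shape: tuple) -> None:
    """Hypothetical helper: fail fast if a predict()-style dict breaks the task contract."""
    if key not in results:
        raise KeyError(f"missing result key {key!r}; got {sorted(results)}")
    actual_shape = np.asarray(results[key]).shape
    if actual_shape != expected_shape:
        raise ValueError(f"{key!r} has shape {actual_shape}, expected {expected_shape}")


# Stand-in for the dict an imputation model's predict() would return:
# 4 samples, 24 time steps, 7 features.
fake_results = {"imputation": np.zeros((4, 24, 7))}
check_result_contract(fake_results, "imputation", (4, 24, 7))
```

A placeholder name copied from a template (for example a key other than "imputation") would raise KeyError here instead of surfacing later in downstream code.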
Step 2: Implement core.py
Your core should focus on model computation only.
For the standard NN path, forward() follows this pattern:
Read tensors from the inputs dict
Compute the model output
Return the task result key (e.g. "imputation")
When calc_criterion=True, also return "loss" (training) or "metric" (validation)
Here is a complete example based on the SAITS pattern:
# pypots/imputation/your_model/core.py
import torch.nn as nn

from ...nn.modules import ModelCore
from ...nn.modules.loss import Criterion


class _YourModel(ModelCore):
    def __init__(
        self,
        n_steps: int,
        n_features: int,
        d_model: int,
        training_loss: Criterion,
        validation_metric: Criterion,
    ):
        super().__init__()
        self.training_loss = training_loss
        # Fall back to the training loss when only the bare Criterion base is given
        if validation_metric.__class__.__name__ == "Criterion":
            self.validation_metric = self.training_loss
        else:
            self.validation_metric = validation_metric

        # Define your model's components (d_model must be divisible by nhead)
        self.embedding = nn.Linear(n_features, d_model)
        self.backbone = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=4,
                dim_feedforward=d_model * 4,
                batch_first=True,  # inputs are [batch, n_steps, d_model]
            ),
            num_layers=2,
        )
        self.output_proj = nn.Linear(d_model, n_features)

    def forward(self, inputs: dict, calc_criterion: bool = False) -> dict:
        X, missing_mask = inputs["X"], inputs["missing_mask"]

        # Model computation
        embedded = self.embedding(X)
        encoded = self.backbone(embedded)
        reconstruction = self.output_proj(encoded)

        # Combine: keep observed values, fill missing with model output
        imputed_data = missing_mask * X + (1 - missing_mask) * reconstruction
        results = {
            "imputation": imputed_data,
            "reconstruction": reconstruction,
        }

        # Loss / metric computation
        if calc_criterion:
            X_ori, indicating_mask = inputs["X_ori"], inputs["indicating_mask"]
            if self.training:
                # Training: return "loss" for backpropagation
                results["loss"] = self.training_loss(
                    reconstruction, X_ori, indicating_mask
                )
            else:
                # Validation: return "metric" for model selection
                results["metric"] = self.validation_metric(
                    reconstruction, X_ori, indicating_mask
                )

        return results
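The two mask operations in forward() are easy to get backwards, so here is a small worked example in NumPy. It is a sketch under stated assumptions: `combine` and `masked_mae` are hypothetical helpers written for this illustration, and `masked_mae` is a plain masked mean absolute error, not the PyPOTS MAE criterion itself.

```python
import numpy as np


def combine(X, missing_mask, reconstruction):
    """Keep observed entries of X; take the model output everywhere else."""
    return missing_mask * X + (1 - missing_mask) * reconstruction


def masked_mae(prediction, target, mask, eps=1e-12):
    """Mean absolute error over positions where mask == 1 (illustrative only)."""
    return np.sum(np.abs(prediction - target) * mask) / (np.sum(mask) + eps)


X_ori = np.array([[1.0, 2.0], [3.0, 4.0]])            # ground truth
indicating_mask = np.array([[0.0, 1.0], [0.0, 0.0]])  # the value 2.0 was hidden on purpose
missing_mask = 1.0 - indicating_mask                  # everything else stays observed
X = X_ori * missing_mask                              # what the model actually sees
reconstruction = np.array([[1.5, 2.5], [2.0, 4.0]])   # pretend model output

imputed = combine(X, missing_mask, reconstruction)
error = masked_mae(reconstruction, X_ori, indicating_mask)
print(imputed.tolist())  # [[1.0, 2.5], [3.0, 4.0]]
print(round(error, 6))   # 0.5
```

Only the hidden position contributes to the error, and only that position is overwritten in the imputed output. The reconstruction errors at observed positions (1.5 vs 1.0 and 2.0 vs 3.0) affect neither.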
Step 3: Implement model.py
Your wrapper owns orchestration. For a standard NN model, it should do five jobs:
Inherit the correct task NN base
Instantiate the core model
Instantiate and initialize the optimizer
Implement _assemble_input_for_training(), _assemble_input_for_validating(), and _assemble_input_for_testing()
Build datasets and dataloaders in fit(), then call _train_model()
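The reason for separate assembly methods is that batches differ by stage. The sketch below isolates that idea in plain Python; the function names are hypothetical, nested lists stand in for tensors, and the five-item and three-item batch layouts are assumptions modeled on the masked-imputation setup used in this guide.

```python
def assemble_for_training(batch: list) -> dict:
    # Training and validation batches carry ground truth and the mask of held-out values.
    indices, X, missing_mask, X_ori, indicating_mask = batch
    return {
        "X": X,
        "missing_mask": missing_mask,
        "X_ori": X_ori,
        "indicating_mask": indicating_mask,
    }


def assemble_for_testing(batch: list) -> dict:
    # Inference batches have no ground truth, so only the model inputs are passed on.
    indices, X, missing_mask = batch
    return {"X": X, "missing_mask": missing_mask}


# Plain lists stand in for tensors here.
train_batch = [[0], [[1.0, 0.0]], [[1, 0]], [[1.0, 2.0]], [[0, 1]]]
test_batch = [[0], [[1.0, 0.0]], [[1, 0]]]

train_inputs = assemble_for_training(train_batch)
test_inputs = assemble_for_testing(test_batch)

# Mixing stages up fails loudly instead of silently running without targets.
mismatch = None
try:
    assemble_for_training(test_batch)
except ValueError as err:
    mismatch = str(err)
```

Keeping the unpacking explicit per stage means a dataset that returns the wrong number of items is caught at the first batch.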
Here is a complete example:
# pypots/imputation/your_model/model.py
from typing import Union, Optional

import numpy as np
import torch
from torch.utils.data import DataLoader

from .core import _YourModel
# Custom dataset assumed to live in data.py (see Step 4); it must yield the
# five-item training batches unpacked in _assemble_input_for_training()
from .data import DatasetForYourModel
from ..base import BaseNNImputer
from ...data.checking import key_in_data_set
from ...nn.modules.loss import Criterion, MAE, MSE
from ...optim.adam import Adam
from ...optim.base import Optimizer


class YourModel(BaseNNImputer):
    """Your model description here.

    Parameters
    ----------
    n_steps :
        The number of time steps in the time-series data sample.

    n_features :
        The number of features in the time-series data sample.

    d_model :
        The dimension of the model's backbone.

    batch_size :
        The batch size for training and evaluating the model.

    epochs :
        The number of epochs for training the model.

    patience :
        The patience for the early-stopping mechanism.

    training_loss :
        The loss function for training. Default: MAE.

    validation_metric :
        The metric function for validation. Default: MSE.

    optimizer :
        The optimizer for model training. Default: Adam.
    """

    def __init__(
        self,
        n_steps: int,
        n_features: int,
        d_model: int = 64,
        batch_size: int = 32,
        epochs: int = 100,
        patience: Optional[int] = None,
        training_loss: Union[Criterion, type] = MAE,
        validation_metric: Union[Criterion, type] = MSE,
        optimizer: Union[Optimizer, type] = Adam,
        num_workers: int = 0,
        device: Optional[Union[str, torch.device, list]] = None,
        saving_path: Optional[str] = None,
        model_saving_strategy: Optional[str] = "best",
        verbose: bool = True,
    ):
        super().__init__(
            training_loss=training_loss,
            validation_metric=validation_metric,
            batch_size=batch_size,
            epochs=epochs,
            patience=patience,
            num_workers=num_workers,
            device=device,
            saving_path=saving_path,
            model_saving_strategy=model_saving_strategy,
            verbose=verbose,
        )

        # Store hyperparameters
        self.n_steps = n_steps
        self.n_features = n_features
        self.d_model = d_model

        # Set up the model
        self.model = _YourModel(
            n_steps=n_steps,
            n_features=n_features,
            d_model=d_model,
            training_loss=self.training_loss,
            validation_metric=self.validation_metric,
        )
        self._print_model_size()
        self._send_model_to_given_device()

        # Set up the optimizer: accept a configured instance or a class
        if isinstance(optimizer, Optimizer):
            self.optimizer = optimizer
        else:
            self.optimizer = optimizer()
            assert isinstance(self.optimizer, Optimizer)
        self.optimizer.init_optimizer(self.model.parameters())

    def _assemble_input_for_training(self, data: list) -> dict:
        (
            indices,
            X,
            missing_mask,
            X_ori,
            indicating_mask,
        ) = self._send_data_to_given_device(data)

        inputs = {
            "X": X,
            "missing_mask": missing_mask,
            "X_ori": X_ori,
            "indicating_mask": indicating_mask,
        }
        return inputs

    def _assemble_input_for_validating(self, data: list) -> dict:
        # Validation batches have the same layout as training batches
        return self._assemble_input_for_training(data)

    def _assemble_input_for_testing(self, data: list) -> dict:
        # No ground truth at inference time, so only X and its mask are passed on
        indices, X, missing_mask = self._send_data_to_given_device(data)

        inputs = {
            "X": X,
            "missing_mask": missing_mask,
        }
        return inputs

    def fit(
        self,
        train_set: Union[dict, str],
        val_set: Optional[Union[dict, str]] = None,
        file_type: str = "hdf5",
    ) -> None:
        # Step 1: Create datasets and dataloaders
        training_set = DatasetForYourModel(
            train_set, return_X_ori=False, return_y=False, file_type=file_type
        )
        train_dataloader = DataLoader(
            training_set,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
        )

        val_dataloader = None
        if val_set is not None:
            # Check validation requirements before any training starts
            if not key_in_data_set("X_ori", val_set):
                raise ValueError("val_set must contain 'X_ori' for validation.")
            val_dataset = DatasetForYourModel(
                val_set, return_X_ori=True, return_y=False, file_type=file_type
            )
            val_dataloader = DataLoader(
                val_dataset,
                batch_size=self.batch_size,
                shuffle=False,
                num_workers=self.num_workers,
            )

        # Step 2: Train the model, then restore the best checkpoint
        self._train_model(train_dataloader, val_dataloader)
        self.model.load_state_dict(self.best_model_dict)

        # Step 3: Auto-save if configured
        self._auto_save_model_if_necessary(
            confirm_saving=self.model_saving_strategy == "best"
        )

    def predict(
        self,
        test_set: Union[dict, str],
        file_type: str = "hdf5",
    ) -> dict:
        result_dict = super().predict(test_set, file_type)
        return result_dict

    def impute(
        self,
        test_set: Union[dict, str],
        file_type: str = "hdf5",
    ) -> np.ndarray:
        results = super().impute(test_set, file_type)
        return results
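The wrapper above accepts either a configured instance or a class for `optimizer`. The following plain-Python sketch isolates that pattern. The `Optimizer` and `Adam` classes here are hypothetical stand-ins that only mirror the names used in the wrapper, and `resolve_optimizer` is a helper invented for this illustration, not a PyPOTS function.

```python
class Optimizer:
    """Hypothetical stand-in for an optimizer wrapper base class."""

    def __init__(self, lr: float = 1e-3):
        self.lr = lr
        self.params = None

    def init_optimizer(self, params) -> None:
        # Bind parameters late, once the model exists.
        self.params = list(params)


class Adam(Optimizer):
    """Hypothetical stand-in; only the name mirrors the wrapper's default."""


def resolve_optimizer(optimizer, params) -> Optimizer:
    """Accept a configured instance or a class, then bind it to the parameters."""
    instance = optimizer if isinstance(optimizer, Optimizer) else optimizer()
    if not isinstance(instance, Optimizer):
        raise TypeError(f"expected an Optimizer, got {type(instance).__name__}")
    instance.init_optimizer(params)
    return instance


# Passing the class uses default settings; passing an instance keeps custom ones.
from_class = resolve_optimizer(Adam, params=[0.1, 0.2])
from_instance = resolve_optimizer(Adam(lr=1e-4), params=[0.1, 0.2])
```

Binding parameters in a separate step is what lets users construct an optimizer with custom settings before the model, and therefore its parameters, exists.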
Step 4: Add data.py Only If Needed
Add data.py only when BaseDataset cannot express your model’s sample contract.
SAITS needs data.py because masked-imputation training requires artificial masking
that BaseDataset does not provide.
If your model can work with BaseDataset directly (or reuse another model’s dataset
like DatasetForBRITS), do not add extra dataset code.
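To show what "artificial masking" means in practice, here is a minimal NumPy sketch. It is an assumption-laden illustration, not the SAITS dataset code: `artificially_mask` and its `rate` parameter are hypothetical, and the real implementation may select positions differently.

```python
import numpy as np


def artificially_mask(X: np.ndarray, rate: float, rng: np.random.Generator):
    """Hide a random share of the observed values so they can serve as training targets."""
    observed = ~np.isnan(X)                          # positions present in the raw data
    X_ori = np.nan_to_num(X, nan=0.0)                # ground truth with NaNs zero-filled
    hide = (rng.random(X.shape) < rate) & observed   # observed positions to hold out
    missing_mask = (observed & ~hide).astype(np.float32)  # 1 = visible to the model
    indicating_mask = hide.astype(np.float32)             # 1 = hidden, truth is known
    X_masked = X_ori * missing_mask                  # zero out everything not visible
    return X_masked, missing_mask, X_ori, indicating_mask


rng = np.random.default_rng(0)
X_raw = rng.normal(size=(2, 4, 3))  # [n_samples, n_steps, n_features]
X_raw[0, 0, 0] = np.nan             # one value that is missing from the start
X_masked, missing_mask, X_ori, indicating_mask = artificially_mask(X_raw, rate=0.2, rng=rng)
```

The four returned arrays match the order unpacked after `indices` in `_assemble_input_for_training()`. A value that was missing from the start is in neither mask, so it is never used as a training target.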
Step 5: Wire the Package
Create the __init__.py to export your model:
# pypots/imputation/your_model/__init__.py
from .model import YourModel
__all__ = ["YourModel"]
Then add the import to the task package’s __init__.py:
# In pypots/imputation/__init__.py, add:
from .your_model import YourModel
Step 6: Keep predict() Boring
The best predict() is usually a thin wrapper over the task base implementation.
SAITS.predict() is a good example — it keeps the public API explicit,
passes through inference-time options, and reuses BaseNNImputer.predict() for the actual loop.
SAITS Walkthrough Summary
Read SAITS in this order for the full picture:
model.py: wrapper, optimizer, dataloaders, input assembly
core.py: forward contract and loss/metric outputs
data.py: why a custom dataset exists
Key things to copy from SAITS:
Wrapper and core responsibilities stay separate
Stage-specific input assembly is explicit
Validation requirements are checked early (e.g. X_ori must exist in val_set)
Best-checkpoint loading happens after training
Definition of Done
Your standard NN integration is done when all of these are true:
fit() runs without overriding _train_model()
Training returns "loss" and validation returns "metric"
predict() returns the correct task result key and shape
Save/load still works
Targeted task tests pass
If you keep fighting the default training loop, you are probably no longer on the standard path. Switch to Complex NN Integration Path.