Data Flow and Dataset Contracts
Most integration bugs in PyPOTS happen at one boundary:
Dataset sample list → Wrapper input dict → Core forward()
This page makes that boundary explicit.
The Normal Data Flow
┌──────────────────────────────────────────────────────────────────┐
│ 1. Dataset.__getitem__(idx)   → returns a sample list            │
│ 2. DataLoader                 → batches those lists              │
│ 3. Wrapper._assemble_input_*  → turns list into a dict           │
│ 4. Core.forward(inputs)       → reads the dict, does computation │
│ 5. Wrapper / task base        → returns task-level result        │
└──────────────────────────────────────────────────────────────────┘
If one stage uses the wrong keys or the wrong shape, the failure usually appears at the boundary between stages 3 and 4.
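To make the boundary concrete, here is a minimal pure-Python sketch of stages 1 to 4. Every name in it (toy_getitem, toy_collate, assemble_input, core_forward) is a hypothetical stand-in rather than PyPOTS code, and nested lists replace tensors so the snippet runs without dependencies.

```python
# Hypothetical stand-ins for the pipeline stages; none of these names are PyPOTS APIs.

def toy_getitem(idx):
    """Stage 1: a dataset returns a positional sample list."""
    X = [[1.0, 0.0], [0.0, 4.0]]         # [n_steps, n_features], missing values filled with 0
    missing_mask = [[1, 0], [0, 1]]      # assumed convention: 1 = observed, 0 = missing
    return [idx, X, missing_mask]


def toy_collate(samples):
    """Stage 2: batch by position, so the item order is preserved."""
    return [list(field) for field in zip(*samples)]


def assemble_input(batch):
    """Stage 3: the wrapper turns positions into named keys."""
    indices, X, missing_mask = batch
    return {"X": X, "missing_mask": missing_mask}


def core_forward(inputs):
    """Stage 4: the core reads by key, so a wrong key fails here and not earlier."""
    X, mask = inputs["X"], inputs["missing_mask"]
    n_observed = sum(v for sample in mask for step in sample for v in step)
    return {"n_observed": n_observed, "batch_size": len(X)}


batch = toy_collate([toy_getitem(i) for i in range(3)])
result = core_forward(assemble_input(batch))
print(result)

# A wrapper that assembles the wrong key only fails once the core reads it.
try:
    core_forward({"X": batch[1], "mask": batch[2]})
except KeyError as err:
    print(f"KeyError at the stage 3/4 boundary: {err}")
```

The dataset and the collate step never complain about a renamed key, which is why the error surfaces only when the core looks it up.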
What BaseDataset Returns
BaseDataset (in pypots/data/dataset/base.py) always starts from the same base items:
[idx, X, missing_mask]
Then it appends optional items based on flags:
| Flag | Extra Items | Why They Exist |
|---|---|---|
| return_X_ori | X_ori, indicating_mask | For models that need original targets or artificial-missing masks |
| return_X_pred | X_pred, X_pred_missing_mask | For forecasting targets |
| return_y | y | For classification or other supervised outputs |
Array input and file-backed input follow the same logical order.
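The append order can be sketched as a plain function. This illustrates the ordering rule described above and is not the BaseDataset source: build_sample is a hypothetical helper, strings stand in for tensors, and the forecasting names (return_X_pred, X_pred, X_pred_missing_mask) are assumptions that should be checked against your installed PyPOTS version.

```python
def build_sample(idx, return_X_ori=False, return_X_pred=False, return_y=False):
    """Return the item names of one sample, in positional order.

    Hypothetical helper: strings stand in for the tensors a real dataset returns.
    """
    sample = [idx, "X", "missing_mask"]                 # base items, always present
    if return_X_ori:
        sample += ["X_ori", "indicating_mask"]          # original targets + artificial-missing mask
    if return_X_pred:
        sample += ["X_pred", "X_pred_missing_mask"]     # forecasting targets (assumed names)
    if return_y:
        sample.append("y")                              # labels for supervised tasks
    return sample


print(build_sample(0))
print(build_sample(0, return_X_ori=True))
print(build_sample(0, return_X_ori=True, return_y=True))
```

Because the wrapper unpacks by position, a wrapper that expects five items will fail as soon as the dataset is built with a different combination of flags.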
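What Each Stage Should Assemble
The three assembly functions (_assemble_input_for_training(), _assemble_input_for_validating(), and _assemble_input_for_testing()) exist because training, validation, and testing do not always need the same tensors.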
Typical pattern:
Training: include everything needed for loss computation
Validation: include everything needed for metric computation
Testing: keep only inference-time inputs
This is why one model may need more tensors in fit() than in predict().
SAITS: A Concrete Data Flow Example
SAITS is the clearest reference for a standard NN model with extra data requirements.
Custom Dataset: DatasetForSAITS
DatasetForSAITS extends BaseDataset because the default dataset is not enough.
It introduces Masked Imputation Training (MIT) — artificially masking a portion of
observed values to create a self-supervised training signal.
from pygrinder import fill_and_get_mask_torch, mcar
from pypots.data.dataset.base import BaseDataset


class DatasetForSAITS(BaseDataset):
    def __init__(self, data, return_X_ori, return_y, file_type="hdf5", rate=0.2):
        super().__init__(
            data,
            return_X_ori=return_X_ori,
            return_X_pred=False,  # SAITS is an imputation model: no forecasting targets
            return_y=return_y,
            file_type=file_type,
        )
        self.rate = rate  # artificial masking rate for MIT

    def _fetch_data_from_array(self, idx):
        # Get original data
        X = self.X[idx]
        missing_mask = self.missing_mask[idx]
        X_ori = self.X_ori[idx]

        # Apply additional artificial masking for MIT:
        # mcar takes the masking probability as p and marks hidden values as NaN
        X_hat = mcar(X, p=self.rate)
        # Fill the NaNs and get a mask of the values that survived the masking
        X_hat, kept_mask = fill_and_get_mask_torch(X_hat)
        missing_mask_hat = missing_mask * kept_mask
        indicating_mask_hat = missing_mask - missing_mask_hat

        return [
            idx,
            X_hat,                # masked input fed to model
            missing_mask_hat,     # observed vs missing in X_hat
            X_ori,                # original targets for loss
            indicating_mask_hat,  # artificially hidden positions
        ]
This shows a clean reason to add a custom dataset: training needs artificial masking, validation needs original targets, testing should stay minimal.
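The mask bookkeeping behind MIT can be checked on a tiny array. The sketch below uses NumPy only and a hand-written random mask in place of mcar, so it illustrates the arithmetic rather than reproducing DatasetForSAITS. apply_mit is a hypothetical helper, and the convention 1 = observed, 0 = missing is an assumption.

```python
import numpy as np


def apply_mit(X, missing_mask, rate, rng):
    """Hide a further share of the observed values (hypothetical helper).

    X            : [n_steps, n_features], missing values already filled with 0
    missing_mask : same shape, 1 = observed, 0 = missing (assumed convention)
    Returns (X_hat, missing_mask_hat, indicating_mask).
    """
    keep = (rng.random(X.shape) >= rate).astype(X.dtype)   # 0 where a value gets hidden
    missing_mask_hat = missing_mask * keep                  # still observed after MIT
    indicating_mask = missing_mask - missing_mask_hat       # observed before, hidden now
    X_hat = X * missing_mask_hat                            # zero out the hidden values
    return X_hat, missing_mask_hat, indicating_mask


rng = np.random.default_rng(0)
X = np.arange(1, 13, dtype=np.float32).reshape(4, 3)
missing_mask = np.ones_like(X)
missing_mask[0, 1] = 0            # one value is genuinely missing
X = X * missing_mask

X_hat, missing_mask_hat, indicating_mask = apply_mit(X, missing_mask, rate=0.3, rng=rng)

# The two masks partition the originally observed positions.
assert np.array_equal(missing_mask_hat + indicating_mask, missing_mask)
# A genuinely missing value is never selected as a training target.
assert indicating_mask[0, 1] == 0
print("hidden for MIT:", int(indicating_mask.sum()))
```

Multiplying by the original mask before subtracting keeps the indicating mask restricted to values that have a ground truth to compare against.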
Training and Validation Assembly
SAITS._assemble_input_for_training() builds this dict:
def _assemble_input_for_training(self, data: list) -> dict:
    (
        indices,
        X,
        missing_mask,
        X_ori,
        indicating_mask,
    ) = self._send_data_to_given_device(data)
    inputs = {
        "X": X,
        "missing_mask": missing_mask,
        "X_ori": X_ori,
        "indicating_mask": indicating_mask,
    }
    return inputs
Inside _SAITS.forward(), X and missing_mask are always used for the model pass.
When calc_criterion=True, X_ori and indicating_mask are also used to compute the
training loss or the validation metric.
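A stripped-down core shows how the four keys are consumed. This is a sketch under stated assumptions, not _SAITS.forward(): toy_forward is hypothetical, the "model" is a per-feature mean imputer, and the criterion is a masked mean absolute error over the artificially hidden positions.

```python
import numpy as np


def toy_forward(inputs, calc_criterion=False):
    """Hypothetical core: impute with the per-feature mean of observed values."""
    X, mask = inputs["X"], inputs["missing_mask"]             # always required
    observed_count = np.maximum(mask.sum(axis=1, keepdims=True), 1)
    feature_mean = (X * mask).sum(axis=1, keepdims=True) / observed_count
    imputation = mask * X + (1 - mask) * feature_mean         # observed values stay untouched
    results = {"imputation": imputation}

    if calc_criterion:
        # Only now are the extra keys read, so inference never needs them.
        X_ori, indicating_mask = inputs["X_ori"], inputs["indicating_mask"]
        abs_err = np.abs(imputation - X_ori) * indicating_mask
        results["loss"] = abs_err.sum() / max(indicating_mask.sum(), 1)
    return results


X_ori = np.array([[[1.0, 10.0], [2.0, 20.0], [5.0, 30.0]]])   # [batch, n_steps, n_features]
indicating_mask = np.zeros_like(X_ori)
indicating_mask[0, 1, 0] = 1                                   # hide one value for training
missing_mask = 1 - indicating_mask
inputs = {"X": X_ori * missing_mask, "missing_mask": missing_mask}

print(toy_forward(inputs)["imputation"][0, 1, 0])              # inference: two keys suffice
inputs.update({"X_ori": X_ori, "indicating_mask": indicating_mask})
print(toy_forward(inputs, calc_criterion=True)["loss"])
```

Reading the extra keys only inside the calc_criterion branch is what lets the testing dict stay minimal.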
Testing Assembly
SAITS._assemble_input_for_testing() intentionally drops the extra tensors:
def _assemble_input_for_testing(self, data: list) -> dict:
    indices, X, missing_mask = self._send_data_to_given_device(data)
    inputs = {
        "X": X,
        "missing_mask": missing_mask,
    }
    return inputs
That is the contract to remember:
Training/validation need X_ori and indicating_mask
Testing does not
The Complete Forward Flow
Training:
DatasetForSAITS → [idx, X, mask, X_ori, indicating_mask]
→ _assemble_input_for_training() → {"X", "missing_mask", "X_ori", "indicating_mask"}
→ _SAITS.forward(inputs, calc_criterion=True)
→ returns {"imputation", "loss", ...}
Validation:
DatasetForSAITS → [idx, X, mask, X_ori, indicating_mask]
→ _assemble_input_for_validating() → {"X", "missing_mask", "X_ori", "indicating_mask"}
→ _SAITS.forward(inputs, calc_criterion=True)
→ returns {"imputation", "metric", ...}
Testing (inference):
BaseDataset → [idx, X, missing_mask]
→ _assemble_input_for_testing() → {"X", "missing_mask"}
→ _SAITS.forward(inputs, calc_criterion=False)
→ returns {"imputation", ...}
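The three flows above amount to a small table of required keys per stage. One way to catch contract drift early is to check the assembled dict against that table before calling forward(). REQUIRED_KEYS and check_contract are hypothetical helpers, not part of PyPOTS.

```python
# Required keys per stage, taken from the SAITS flows above.
REQUIRED_KEYS = {
    "training": {"X", "missing_mask", "X_ori", "indicating_mask"},
    "validating": {"X", "missing_mask", "X_ori", "indicating_mask"},
    "testing": {"X", "missing_mask"},
}


def check_contract(stage, inputs):
    """Raise a readable error if the dict does not match the stage contract."""
    expected = REQUIRED_KEYS[stage]
    actual = set(inputs)
    missing = expected - actual
    extra = actual - expected
    if missing or extra:
        raise ValueError(
            f"{stage}: missing keys {sorted(missing)}, unexpected keys {sorted(extra)}"
        )


check_contract("testing", {"X": None, "missing_mask": None})    # passes silently

try:
    check_contract("training", {"X": None, "missing_mask": None})
except ValueError as err:
    print(err)
```

Calling a check like this at the end of each assembly function turns a late KeyError inside the core into an error that names the stage and the offending keys.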
When You Need a Custom Dataset
Add data.py only when BaseDataset cannot express your model’s sample contract.
Good reasons to create a custom dataset:
You need artificial masking like SAITS
You need extra stage-dependent tensors
File-mode loading needs special handling
The default sample order does not cover your model
Bad reason:
You only want to rename keys that the wrapper could assemble directly
Most models in PyPOTS do not need a custom dataset.
Check if BaseDataset can handle your requirements first.
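For the "bad reason" case, the wrapper can do the renaming on its own. The sketch below assumes a hypothetical core that expects the keys inputs and mask; assemble_with_renamed_keys is an illustrative name, and strings stand in for batched tensors.

```python
def assemble_with_renamed_keys(data):
    """Map the default sample order onto the names a hypothetical core expects."""
    indices, X, missing_mask = data          # default order from the base dataset
    return {"inputs": X, "mask": missing_mask}


batch = [[0, 1], "X_batch", "missing_mask_batch"]   # placeholders for batched tensors
print(assemble_with_renamed_keys(batch))
```

No dataset subclass is involved: the sample contract is unchanged and only the dict keys differ.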
Data Input Formats
PyPOTS supports two input modes for all models:
Dict input (in-memory):
import numpy as np

train_set = {
    "X": np.array(...),      # shape: [n_samples, n_steps, n_features]
    "y": np.array(...),      # shape: [n_samples] (optional)
}
val_set = {
    "X": np.array(...),
    "X_ori": np.array(...),  # original data for validation metric
    "y": np.array(...),
}
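Shape mistakes in these dicts are cheap to catch before training starts. The checker below is a sketch, not a PyPOTS function: check_dataset_dict is hypothetical, and the rules it enforces are only the shapes noted in the comments above. The random arrays are placeholder data.

```python
import numpy as np


def check_dataset_dict(dataset, name="dataset"):
    """Validate the shapes of an in-memory dataset dict (hypothetical helper)."""
    X = np.asarray(dataset["X"])
    if X.ndim != 3:
        raise ValueError(f"{name}['X'] must be [n_samples, n_steps, n_features], got {X.shape}")
    if "X_ori" in dataset and np.asarray(dataset["X_ori"]).shape != X.shape:
        raise ValueError(f"{name}['X_ori'] must have the same shape as X")
    if "y" in dataset and len(dataset["y"]) != X.shape[0]:
        raise ValueError(f"{name}['y'] must have one entry per sample")
    return X.shape


rng = np.random.default_rng(0)
X_ori = rng.normal(size=(8, 24, 5))
X = X_ori.copy()
X[rng.random(X.shape) < 0.1] = np.nan          # missing values are NaN in the raw input

train_set = {"X": X, "y": rng.integers(0, 2, size=8)}
val_set = {"X": X, "X_ori": X_ori, "y": rng.integers(0, 2, size=8)}
print(check_dataset_dict(train_set, "train_set"))
print(check_dataset_dict(val_set, "val_set"))
```

Running the same check on arrays read back from a saved file is a quick way to confirm that both input modes carry the same contract.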
File input (lazy-loading from HDF5):
from pypots.data.saving import save_dict_into_h5
# Save data to HDF5 files
save_dict_into_h5(train_set, "train_set.h5")
save_dict_into_h5(val_set, "val_set.h5")
# Use file paths instead of dicts
model.fit("train_set.h5", "val_set.h5")
results = model.predict("test_set.h5")
Both modes follow the same logical order. Make sure your model handles both consistently.
Fast Debugging Checklist
Before changing model math, check these first:
Print the sample length and item order from the dataset
Print the dict keys right before forward()
Print tensor shapes for train, validation, and test separately
Verify array input and file input follow the same contract
Verify masks mean the same thing in dataset code and model code
# Quick debugging snippet for your _assemble_input_for_training:
def _assemble_input_for_training(self, data: list) -> dict:
    print(f"Number of items in data: {len(data)}")
    for i, item in enumerate(data):
        if hasattr(item, 'shape'):
            print(f"  data[{i}]: shape={item.shape}, dtype={item.dtype}")
        else:
            print(f"  data[{i}]: type={type(item)}")
    # ... your actual assembly logic
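The last checklist item, mask semantics, can also be automated. This sketch assumes the convention 1 = observed and 0 = missing, with missing positions in X filled rather than left as NaN; check_mask_semantics is a hypothetical helper.

```python
import numpy as np


def check_mask_semantics(X_raw, X, missing_mask):
    """Check that the mask marks observed values with 1 (hypothetical helper).

    X_raw        : raw data with NaN at missing positions
    X            : model input with missing positions filled
    missing_mask : 1 = observed, 0 = missing (assumed convention)
    """
    observed = ~np.isnan(X_raw)
    assert np.array_equal(missing_mask.astype(bool), observed), "mask is inverted or stale"
    assert not np.isnan(X).any(), "model input still contains NaN"
    assert np.allclose(X[observed], X_raw[observed]), "observed values were altered"


X_raw = np.array([[1.0, np.nan], [np.nan, 4.0]])
missing_mask = (~np.isnan(X_raw)).astype(np.float32)
X = np.nan_to_num(X_raw, nan=0.0)

check_mask_semantics(X_raw, X, missing_mask)           # passes
try:
    check_mask_semantics(X_raw, X, 1 - missing_mask)   # inverted mask
except AssertionError as err:
    print(f"caught: {err}")
```

An inverted mask produces no shape error anywhere in the pipeline, so an explicit check like this is often the only way to notice it.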
In PyPOTS, many “model bugs” are really data-contract bugs.