Data Flow and Dataset Contracts

Most integration bugs in PyPOTS happen at one boundary:

Dataset sample list → Wrapper input dict → Core forward()

This page makes that boundary explicit.

The Normal Data Flow

┌─────────────────────────────────────────────────────────────────────┐
│  1. Dataset.__getitem__(idx)   → returns a sample list             │
│  2. DataLoader                 → batches those lists               │
│  3. Wrapper._assemble_input_*  → turns list into a dict           │
│  4. Core.forward(inputs)       → reads the dict, does computation  │
│  5. Wrapper / task base        → returns task-level result         │
└─────────────────────────────────────────────────────────────────────┘

If one stage uses the wrong keys or the wrong shape, the failure usually appears at the boundary between stages 3 and 4.

What BaseDataset Returns

BaseDataset (in pypots/data/dataset/base.py) always starts from the same base items:

[idx, X, missing_mask]

Then it appends optional items based on flags:

Flag

Extra Items

Why They Exist

return_X_ori=True

X_ori, indicating_mask

For models that need original targets or artificial-missing masks

return_X_pred=True

X_pred, X_pred_missing_mask

For forecasting targets

return_y=True

y

For classification or other supervised outputs

Array input and file-backed input follow the same logical order.

What Each Stage Should Assemble

The three assembly functions exist because train, validation, and test do not always need the same tensors.

Typical pattern:

  • Training: include everything needed for loss computation

  • Validation: include everything needed for metric computation

  • Testing: keep only inference-time inputs

This is why one model may need more tensors in fit() than in predict().

SAITS: A Concrete Data Flow Example

SAITS is the clearest reference for a standard NN model with extra data requirements.

Custom Dataset: DatasetForSAITS

DatasetForSAITS extends BaseDataset because the default dataset is not enough. It introduces Masked Imputation Training (MIT) — artificially masking a portion of observed values to create a self-supervised training signal.

from pypots.data.dataset.base import BaseDataset
from pygrinder import mcar

class DatasetForSAITS(BaseDataset):
    def __init__(self, data, return_X_ori, return_y, file_type="hdf5", rate=0.2):
        super().__init__(data, return_X_ori=return_X_ori, return_y=return_y,
                         file_type=file_type)
        self.rate = rate  # artificial masking rate for MIT

    def _fetch_data_from_array(self, idx):
        # Get original data
        X = self.X[idx]
        missing_mask = self.missing_mask[idx]
        X_ori = self.X_ori[idx]
        indicating_mask = self.indicating_mask[idx]

        # Apply additional artificial masking for MIT
        X_hat, missing_mask_hat = mcar(X, rate=self.rate)
        indicating_mask_hat = missing_mask - missing_mask_hat

        return (
            idx,
            X_hat,             # masked input fed to model
            missing_mask_hat,  # observed vs missing in X_hat
            X_ori,             # original targets for loss
            indicating_mask_hat # artificially hidden positions
        )

This shows a clean reason to add a custom dataset: training needs artificial masking, validation needs original targets, testing should stay minimal.

Training and Validation Assembly

SAITS._assemble_input_for_training() builds this dict:

def _assemble_input_for_training(self, data: list) -> dict:
    (
        indices,
        X,
        missing_mask,
        X_ori,
        indicating_mask,
    ) = self._send_data_to_given_device(data)

    inputs = {
        "X": X,
        "missing_mask": missing_mask,
        "X_ori": X_ori,
        "indicating_mask": indicating_mask,
    }
    return inputs

Inside _SAITS.forward(), X and missing_mask are always used for the model pass. When calc_criterion=True, X_ori and indicating_mask are used to produce training loss or validation metric.

Testing Assembly

SAITS._assemble_input_for_testing() intentionally drops the extra tensors:

def _assemble_input_for_testing(self, data: list) -> dict:
    indices, X, missing_mask = self._send_data_to_given_device(data)

    inputs = {
        "X": X,
        "missing_mask": missing_mask,
    }
    return inputs

That is the contract to remember:

  • Training/validation need X_ori and indicating_mask

  • Testing does not

The Complete Forward Flow

Training:
DatasetForSAITS → [idx, X, mask, X_ori, indicating_mask]
  → _assemble_input_for_training() → {"X", "missing_mask", "X_ori", "indicating_mask"}
    → _SAITS.forward(inputs, calc_criterion=True)
      → returns {"imputation", "loss", ...}

Validation:
DatasetForSAITS → [idx, X, mask, X_ori, indicating_mask]
  → _assemble_input_for_validating() → {"X", "missing_mask", "X_ori", "indicating_mask"}
    → _SAITS.forward(inputs, calc_criterion=True)
      → returns {"imputation", "metric", ...}

Testing (inference):
BaseDataset → [idx, X, missing_mask]
  → _assemble_input_for_testing() → {"X", "missing_mask"}
    → _SAITS.forward(inputs, calc_criterion=False)
      → returns {"imputation", ...}

When You Need a Custom Dataset

Add data.py only when BaseDataset cannot express your model’s sample contract.

Good reasons to create a custom dataset:

  • You need artificial masking like SAITS

  • You need extra stage-dependent tensors

  • File-mode loading needs special handling

  • The default sample order does not cover your model

Bad reason:

  • You only want to rename keys that the wrapper could assemble directly

Most models in PyPOTS do not need a custom dataset. Check if BaseDataset can handle your requirements first.

Data Input Formats

PyPOTS supports two input modes for all models:

Dict input (in-memory):

train_set = {
    "X": np.array(...),        # shape: [n_samples, n_steps, n_features]
    "y": np.array(...),        # shape: [n_samples] (optional)
}
val_set = {
    "X": np.array(...),
    "X_ori": np.array(...),    # original data for validation metric
    "y": np.array(...),
}

File input (lazy-loading from HDF5):

from pypots.data.saving import save_dict_into_h5

# Save data to HDF5 files
save_dict_into_h5(train_set, "train_set.h5")
save_dict_into_h5(val_set, "val_set.h5")

# Use file paths instead of dicts
model.fit("train_set.h5", "val_set.h5")
results = model.predict("test_set.h5")

Both modes follow the same logical order. Make sure your model handles both consistently.

Fast Debugging Checklist

Before changing model math, check these first:

  1. Print the sample length and item order from the dataset

  2. Print the dict keys right before forward()

  3. Print tensor shapes for train, validation, and test separately

  4. Verify array input and file input follow the same contract

  5. Verify masks mean the same thing in dataset code and model code

# Quick debugging snippet for your _assemble_input_for_training:
def _assemble_input_for_training(self, data: list) -> dict:
    print(f"Number of items in data: {len(data)}")
    for i, item in enumerate(data):
        if hasattr(item, 'shape'):
            print(f"  data[{i}]: shape={item.shape}, dtype={item.dtype}")
        else:
            print(f"  data[{i}]: type={type(item)}")
    # ... your actual assembly logic

In PyPOTS, many “model bugs” are really data-contract bugs.