"""
The implementation of HELIX for the partially-observed time-series imputation task.
"""
# Created by Fengming Zhang <milaogou@gmail.com>
# License: BSD-3-Clause
from typing import Union, Optional
import torch
from torch.utils.data import DataLoader
from .core import _HELIX
from ..base import BaseNNImputer
from ..saits.data import DatasetForSAITS
from ...data.checking import key_in_data_set
from ...nn.modules.loss import Criterion, MAE, MSE
from ...optim.adam import Adam
from ...optim.base import Optimizer
from ...utils.logging import logger
class HELIX(BaseNNImputer):
    """The PyTorch implementation of the HELIX: Hybrid Encoding with Learnable Identity
    and Cross-dimensional Synthesis for Time Series Imputation :cite:`zhang2026helix`.

    Parameters
    ----------
    n_steps :
        The number of time steps in the time-series data sample.

    n_features :
        The number of features in the time-series data sample.

    d_pe :
        The dimension of the positional encoding for temporal dimension.
        Total embedding dimension will be ``d_pe + d_feature_embed + 2``
        (data + temporal_pe + feature_id + mask).

    d_feature_embed :
        The dimension of the learnable feature identity embedding.

    d_model :
        The dimension of the model's hidden states.
        Should be divisible by ``n_heads``. If it is not, it is rounded down to the
        nearest multiple of ``n_heads`` and a warning is logged.

    n_heads :
        The number of attention heads.

    n_layers :
        The number of hybrid encoding layers.

    dropout :
        The dropout rate for all layers.

    ORT_weight :
        The weight for the Observed Reconstruction Task (ORT) loss.

    MIT_weight :
        The weight for the Masked Imputation Task (MIT) loss.

    batch_size :
        The batch size for training and evaluating the model.

    epochs :
        The number of epochs for training the model.

    patience :
        The patience for the early-stopping mechanism. Given a positive integer, the training process will be
        stopped when the model does not perform better after that number of epochs.
        Leaving it default as None will disable the early-stopping.

    training_loss :
        The customized loss function designed by users for training the model.
        If not given, will use MAE as default.

    validation_metric :
        The customized metric function designed by users for validating the model.
        If not given, will use MSE as default.

    optimizer :
        The optimizer for model training. Either an ``Optimizer`` instance or a class/callable
        that returns one when called without arguments.
        If not given, will use a default Adam optimizer.

    num_workers :
        The number of subprocesses to use for data loading.
        `0` means data loading will be in the main process.

    device :
        The device for the model to run on.

    saving_path :
        The path for automatically saving model checkpoints and tensorboard files.

    model_saving_strategy :
        The strategy to save model checkpoints. It has to be one of [None, "best", "better", "all"].

    verbose :
        Whether to print out the training logs during the training process.

    Raises
    ------
    ValueError
        If ``d_model`` is not divisible by ``n_heads`` and rounding it down to a multiple of
        ``n_heads`` would leave no hidden dimension at all (e.g. ``d_model < n_heads``).

    """

    def __init__(
        self,
        n_steps: int,
        n_features: int,
        d_pe: int = 16,
        d_feature_embed: int = 1,
        d_model: int = 256,
        n_heads: int = 8,
        n_layers: int = 2,
        dropout: float = 0.1,
        ORT_weight: float = 1.0,
        MIT_weight: float = 1.0,
        batch_size: int = 32,
        epochs: int = 100,
        patience: Optional[int] = None,
        training_loss: Union[Criterion, type] = MAE,
        validation_metric: Union[Criterion, type] = MSE,
        optimizer: Union[Optimizer, type] = Adam,
        num_workers: int = 0,
        device: Optional[Union[str, torch.device, list]] = None,
        saving_path: Optional[str] = None,
        model_saving_strategy: Optional[str] = "best",
        verbose: bool = True,
    ):
        super().__init__(
            training_loss=training_loss,
            validation_metric=validation_metric,
            batch_size=batch_size,
            epochs=epochs,
            patience=patience,
            num_workers=num_workers,
            device=device,
            saving_path=saving_path,
            model_saving_strategy=model_saving_strategy,
            verbose=verbose,
        )

        # Multi-head attention splits d_model evenly across heads, so make d_model a multiple of n_heads.
        if d_model % n_heads != 0:
            adjusted_d_model = n_heads * (d_model // n_heads)
            # Rounding down must still leave a usable hidden size; d_model < n_heads would collapse to 0.
            if adjusted_d_model <= 0:
                raise ValueError(
                    f"d_model ({d_model}) is not divisible by n_heads ({n_heads}) and cannot be rounded down "
                    f"to a positive multiple of n_heads. Please give d_model >= n_heads."
                )
            logger.warning(f"‼️ d_model ({d_model}) must be divisible by n_heads ({n_heads})")
            d_model = adjusted_d_model
            logger.warning(f"⚠️ d_model is adjusted to {d_model}")

        # Keep the (possibly adjusted) hyperparameters on the instance
        self.n_steps = n_steps
        self.n_features = n_features
        self.d_pe = d_pe
        self.d_feature_embed = d_feature_embed
        self.d_model = d_model
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.dropout = dropout
        self.ORT_weight = ORT_weight
        self.MIT_weight = MIT_weight

        # Set up the model
        self.model = _HELIX(
            n_steps=n_steps,
            n_features=n_features,
            d_pe=d_pe,
            d_feature_embed=d_feature_embed,
            d_model=d_model,
            n_heads=n_heads,
            n_layers=n_layers,
            dropout=dropout,
            ORT_weight=ORT_weight,
            MIT_weight=MIT_weight,
            training_loss=self.training_loss,
            validation_metric=self.validation_metric,
        )
        self._print_model_size()
        self._send_model_to_given_device()

        # Set up the optimizer: accept a ready instance, otherwise instantiate the given class
        if isinstance(optimizer, Optimizer):
            self.optimizer = optimizer
        else:
            self.optimizer = optimizer()
            assert isinstance(
                self.optimizer, Optimizer
            ), "optimizer must be an Optimizer instance or a class that creates one"
        self.optimizer.init_optimizer(self.model.parameters())

    def _assemble_input_for_training(self, data: list) -> dict:
        """Assemble the model input dict from a training batch.

        Parameters
        ----------
        data :
            A batch of five items, in order: sample indices, ``X``, ``missing_mask``,
            ``X_ori`` and ``indicating_mask``.

        Returns
        -------
        inputs :
            A dict with keys "X", "missing_mask", "X_ori" and "indicating_mask".
            The sample indices are not forwarded to the model.

        """
        _, X, missing_mask, X_ori, indicating_mask = self._send_data_to_given_device(data)
        return {
            "X": X,
            "missing_mask": missing_mask,
            "X_ori": X_ori,
            "indicating_mask": indicating_mask,
        }

    def _assemble_input_for_validating(self, data: list) -> dict:
        """Assemble the model input dict from a validation batch.

        Validation batches are handled exactly like training batches.
        """
        return self._assemble_input_for_training(data)

    def _assemble_input_for_testing(self, data: list) -> dict:
        """Assemble the model input dict from a testing batch.

        Parameters
        ----------
        data :
            A batch of three items, in order: sample indices, ``X`` and ``missing_mask``.

        Returns
        -------
        inputs :
            A dict with keys "X" and "missing_mask".
            The sample indices are not forwarded to the model.

        """
        _, X, missing_mask = self._send_data_to_given_device(data)
        return {
            "X": X,
            "missing_mask": missing_mask,
        }

    def fit(
        self,
        train_set: Union[dict, str],
        val_set: Optional[Union[dict, str]] = None,
        file_type: str = "hdf5",
    ) -> None:
        """Train the HELIX model.

        After training, the best model weights recorded during training are loaded back
        into the model, and the model is saved if the saving strategy asks for it.

        Parameters
        ----------
        train_set :
            The training dataset.

        val_set :
            The validation dataset. If given, it must contain 'X_ori'.

        file_type :
            The type of the data file if train_set/val_set are file paths.

        Raises
        ------
        ValueError
            If ``val_set`` is given but does not contain 'X_ori'.

        """
        # Build the training data loader
        train_dataset = DatasetForSAITS(train_set, return_X_ori=False, return_y=False, file_type=file_type)
        train_dataloader = DataLoader(
            train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
        )

        # Build the validation data loader only when a validation set is provided
        val_dataloader = None
        if val_set is not None:
            if not key_in_data_set("X_ori", val_set):
                raise ValueError("val_set must contain 'X_ori' for model validation.")
            val_dataset = DatasetForSAITS(val_set, return_X_ori=True, return_y=False, file_type=file_type)
            val_dataloader = DataLoader(
                val_dataset,
                batch_size=self.batch_size,
                shuffle=False,
                num_workers=self.num_workers,
            )

        # Train the model, then restore the best weights recorded during training
        self._train_model(train_dataloader, val_dataloader)
        self.model.load_state_dict(self.best_model_dict)

        # Save the model; saving is confirmed only under the "best" strategy
        self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best")

    @torch.no_grad()
    def predict(
        self,
        test_set: Union[dict, str],
        file_type: str = "hdf5",
    ) -> dict:
        """Make predictions for the input data with the trained model.

        Delegates to the parent class implementation with gradient tracking disabled.

        Parameters
        ----------
        test_set :
            The dataset for testing.

        file_type :
            The type of the given file if test_set is a path string.

        Returns
        -------
        result_dict :
            The dictionary containing the imputation results.

        """
        return super().predict(test_set, file_type)