Quick-start Examples

BrewPOTS logo

We provide some examples here to help users get started quickly.

Please refer to BrewPOTS for detailed PyPOTS tutorials. You can also find a simple quick-start tutorial notebook on Google Colab.


import numpy as np
from sklearn.preprocessing import StandardScaler
from pygrinder import mcar, calc_missing_rate
from benchpots.datasets import preprocess_physionet2012
# SAITS and calc_mae are used further down, so they have to be imported explicitly;
# without these two lines the example stops with a NameError.
# NOTE(review): calc_mae is taken from pypots.nn.functional here; older PyPOTS releases
# may expose it as pypots.utils.metrics.calc_mae instead -- confirm against the installed version.
from pypots.imputation import SAITS
from pypots.nn.functional import calc_mae

# prepare the dataset
data = preprocess_physionet2012(subset="set-a", rate=0.1)  # Our ecosystem libs will automatically download and extract it
train_X, val_X, test_X = data["train_X"], data["val_X"], data["test_X"]
print(train_X.shape)  # (n_samples, n_steps, n_features)
print(val_X.shape)  # samples (n_samples) in train set and val set are different, but they have the same sequence len (n_steps) and feature dim (n_features)
print(f"We have {calc_missing_rate(train_X):.1%} values missing in train_X")

# organize the dataset for PyPOTS model input
train_set = {"X": train_X}  # in training set, simply put the incomplete time series into it
val_set = {
    "X": val_X,
    "X_ori": data["val_X_ori"],  # in validation set, we need ground truth for evaluation and picking the best model checkpoint
}
test_set = {"X": test_X}  # in test set, only give the testing incomplete time series for model to impute

# the test set for final evaluation
test_X_ori = data["test_X_ori"]  # test_X_ori bears ground truth for evaluation
# the mask is True exactly where a value is missing in test_X but present in test_X_ori,
# i.e. at the positions where ground-truth values are available for scoring
indicating_mask = np.isnan(test_X) ^ np.isnan(test_X_ori)

# initialize the model
_, n_steps, n_features = train_X.shape
saits = SAITS(
    n_steps=n_steps,
    n_features=n_features,
    n_layers=2,
    d_model=256,
    d_ffn=128,
    n_heads=4,
    d_k=64,
    d_v=64,
    dropout=0.1,
    epochs=10,
    saving_path="examples/saits",  # set the path for saving tensorboard logging file and model checkpoint
    model_saving_strategy="best",  # only save the model with the best validation performance
)

# train the model. You can also omit the val_set if you don't need to validate the model during training
saits.fit(train_set, val_set)
# impute the originally-missing values and artificially-missing values
imputation = saits.impute(test_set)
# calculate mean absolute error on the ground truth (artificially-missing values);
# NaNs in test_X_ori are replaced with 0 first, and the mask restricts the error to the held-out positions
mae = calc_mae(imputation, np.nan_to_num(test_X_ori), indicating_mask)

# the best model has been already saved, but you can still manually save it with save() as below
saits.save(saving_path="examples/saits/manually_saved_saits_model")
# you can load the saved model into a new initialized model
saits.load("examples/saits/manually_saved_saits_model.pypots")