Source code for pypots.utils.random

"""
PyPOTS util module about random seed setting.
"""

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause

import random

import numpy as np
import torch

from .logging import logger

# Take the birth year of PyPOTS as the default random seed
RANDOM_SEED: int = 2022


[docs] def set_random_seed(random_seed: int = RANDOM_SEED) -> None: """Manually set the random state to make PyPOTS output reproducible results. Parameters ---------- random_seed : The seed to be set for generating random numbers in PyPOTS. """ globals()["RANDOM_SEED"] = random_seed random.seed(random_seed) np.random.seed(random_seed) torch.manual_seed(random_seed) torch.cuda.manual_seed_all(random_seed) # torch.backends.cudnn.deterministic = True # This will slow down the training process. logger.info(f"Have set the random seed as {random_seed} for numpy and pytorch.")
[docs] def get_random_seed() -> int: """Get the random seed used in PyPOTS. Returns ------- random_seed : The random seed used in PyPOTS. """ return RANDOM_SEED