""" """# Created by Wenjie Du <wenjay.du@gmail.com># License: BSD-3-Clauseimporttorchimporttorch.nnasnnfrom..transformerimportPositionalEncoding
class SaitsEmbedding(nn.Module):
    """Input embedding module from the SAITS paper :cite:`du2023SAITS`.

    Builds a linear projection from ``d_in`` to ``d_out`` features, along
    with an optional ``PositionalEncoding`` submodule and an optional
    dropout layer.

    Parameters
    ----------
    d_in :
        Number of input features consumed by the linear embedding layer.

    d_out :
        Number of output features produced by the linear embedding layer.
        This is also the dimension handed to the positional encoding.

    with_pos :
        When True, a ``PositionalEncoding`` submodule is constructed.
        Otherwise ``position_enc`` is set to None.

    n_max_steps :
        Maximum number of steps, passed as ``n_positions`` to the positional
        encoding. Ignored when ``with_pos`` is False.

    dropout :
        Dropout probability. An ``nn.Dropout`` layer is built only when this
        value is strictly greater than 0. Otherwise ``self.dropout`` is None.
    """

    def __init__(
        self,
        d_in: int,
        d_out: int,
        with_pos: bool,
        n_max_steps: int = 1000,
        dropout: float = 0,
    ):
        super().__init__()
        # Remember the raw configuration values as given by the caller.
        self.with_pos = with_pos
        self.dropout_rate = dropout

        # Submodules are assigned in a fixed order (embedding, position,
        # dropout) so parameter/submodule registration order stays stable.
        self.embedding_layer = nn.Linear(in_features=d_in, out_features=d_out)

        if with_pos:
            self.position_enc = PositionalEncoding(d_out, n_positions=n_max_steps)
        else:
            self.position_enc = None

        if dropout > 0:
            self.dropout = nn.Dropout(p=dropout)
        else:
            # No dropout layer unless the rate is strictly positive.
            self.dropout = None