Source code for pypots.nn.modules.segrnn.backbone

""" """

# Created by Shengsheng Lin

import torch
import torch.nn as nn


[docs] class BackboneSegRNN(nn.Module): def __init__( self, n_steps: int, n_features: int, n_pred_steps: int, seg_len: int = 24, d_model: int = 512, dropout: float = 0.5, ): super().__init__() if n_steps % seg_len: raise ValueError("The argument seg_len in SegRNN need to be divisible by the sequence length n_steps.") if n_pred_steps % seg_len: raise ValueError( "The argument seg_len in SegRNN need to be divisible by the prediction sequence length n_pred_steps." ) self.n_steps = n_steps self.n_features = n_features self.n_pred_steps = n_pred_steps self.seg_len = seg_len self.d_model = d_model self.dropout = dropout self.seg_num_x = self.n_steps // self.seg_len self.seg_num_y = self.n_pred_steps // self.seg_len self.valueEmbedding = nn.Sequential(nn.Linear(self.seg_len, self.d_model), nn.ReLU()) self.rnn = nn.GRU( input_size=self.d_model, hidden_size=self.d_model, num_layers=1, bias=True, batch_first=True, bidirectional=False, ) self.pos_emb = nn.Parameter(torch.randn(self.seg_num_y, self.d_model // 2)) self.channel_emb = nn.Parameter(torch.randn(self.n_features, self.d_model // 2)) self.predict = nn.Sequential(nn.Dropout(self.dropout), nn.Linear(self.d_model, self.seg_len))
[docs] def forward(self, x): # b:batch_size c:channel_size s:seq_len s:seq_len # d:d_model w:seg_len n m:seg_num batch_size = x.size(0) # normalization and permute b,s,c -> b,c,s seq_last = x[:, -1:, :].detach() x = (x - seq_last).permute(0, 2, 1) # b,c,s # segment and embedding b,c,s -> bc,n,w -> bc,n,d x = self.valueEmbedding(x.reshape(-1, self.seg_num_x, self.seg_len)) # encoding _, hn = self.rnn(x) # bc,n,d 1,bc,d # m,d//2 -> 1,m,d//2 -> c,m,d//2 # c,d//2 -> c,1,d//2 -> c,m,d//2 # c,m,d -> cm,1,d -> bcm, 1, d pos_emb = ( torch.cat( [ self.pos_emb.unsqueeze(0).repeat(self.n_features, 1, 1), self.channel_emb.unsqueeze(1).repeat(1, self.seg_num_y, 1), ], dim=-1, ) .view(-1, 1, self.d_model) .repeat(batch_size, 1, 1) ) _, hy = self.rnn(pos_emb, hn.repeat(1, 1, self.seg_num_y).view(1, -1, self.d_model)) # bcm,1,d 1,bcm,d # 1,bcm,d -> 1,bcm,w -> b,c,s y = self.predict(hy).view(-1, self.n_features, self.n_pred_steps) # permute and denorm y = y.permute(0, 2, 1) + seq_last return y