# Source code for pypots.nn.modules.etsformer.autoencoder
""" """
# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause
import torch.nn as nn
class ETSformerEncoder(nn.Module):
    """Stack of ETSformer encoder layers applied one after another.

    Parameters
    ----------
    layers :
        Sequence of encoder layer modules. They are registered as an
        ``nn.ModuleList`` so their parameters belong to this module.
    """

    def __init__(self, layers):
        super().__init__()
        self.layers = nn.ModuleList(layers)

    def forward(self, res, level, attn_mask=None):
        """Feed ``res`` and ``level`` through every layer in order.

        Each layer is invoked as ``layer(res, level, attn_mask=attn_mask)`` and
        must return a 4-tuple ``(res, level, growth, season)``. The updated
        ``res`` and ``level`` become the inputs of the next layer.

        Parameters
        ----------
        res :
            Residual input of the first layer.
        level :
            Level input of the first layer.
        attn_mask :
            Forwarded unchanged to every layer (default ``None``).

        Returns
        -------
        tuple
            ``(level, growths, seasons)``: the level returned by the last
            layer (or the input ``level`` if there are no layers), and two
            lists holding each layer's growth and season output in layer
            order.
        """
        growth_terms, season_terms = [], []
        for encoder_layer in self.layers:
            res, level, growth_term, season_term = encoder_layer(
                res, level, attn_mask=attn_mask
            )
            growth_terms.append(growth_term)
            season_terms.append(season_term)
        # The residual produced by the last layer is not part of the output.
        return level, growth_terms, season_terms
class ETSformerDecoder(nn.Module):
    """Stack of ETSformer decoder layers followed by a shared output projection.

    Hyperparameters are copied from the first layer only:
    ``d_model``, ``d_out``, ``pred_len`` and ``n_heads`` (the latter is stored
    as ``self.n_head``). Presumably every layer shares these values; only
    ``layers[0]`` is inspected, so an empty ``layers`` raises ``IndexError``.

    Parameters
    ----------
    layers :
        Sequence of decoder layer modules, each exposing the attributes above.
    """

    def __init__(self, layers):
        super().__init__()
        head_layer = layers[0]
        self.d_model = head_layer.d_model
        self.d_out = head_layer.d_out
        self.pred_len = head_layer.pred_len
        self.n_head = head_layer.n_heads
        # Registration order (layers, then pred) fixes the parameter /
        # state_dict ordering, so it is kept as-is.
        self.layers = nn.ModuleList(layers)
        self.pred = nn.Linear(self.d_model, self.d_out)

    def forward(self, growths, seasons):
        """Decode per-layer growth/season inputs and project their sums.

        Layer ``i`` is called as ``layer(growths[i], seasons[i])`` and must
        return ``(growth_horizon, season_horizon)``. The horizons are summed
        across layers, and the same ``self.pred`` linear layer is applied to
        each of the two totals.

        Parameters
        ----------
        growths :
            Indexable container with at least one entry per decoder layer.
        seasons :
            Indexable container with at least one entry per decoder layer.

        Returns
        -------
        tuple
            ``(pred(sum of growth horizons), pred(sum of season horizons))``.
        """
        growth_total = 0
        season_total = 0
        for i, decoder_layer in enumerate(self.layers):
            growth_horizon, season_horizon = decoder_layer(growths[i], seasons[i])
            # Accumulate starting from 0, matching builtin ``sum`` semantics.
            growth_total = growth_total + growth_horizon
            season_total = season_total + season_horizon
        return self.pred(growth_total), self.pred(season_total)