Source code for pypots.nn.modules.informer.autoencoder

""" """

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause

import torch.nn as nn


class InformerEncoder(nn.Module):
    """Sequential stack of attention layers with optional interleaved conv layers.

    Parameters
    ----------
    attn_layers :
        Iterable of modules. Each is called as ``layer(x, attn_mask=...)`` and
        must return a ``(output, attention)`` pair.
    conv_layers :
        Optional iterable of modules, each called as ``layer(x)``. When given,
        they are paired with ``attn_layers`` via ``zip`` and the last attention
        layer is applied once more after the paired pass.
        NOTE(review): this presumably expects one fewer conv layer than
        attention layers -- confirm against the caller.
    norm_layer :
        Optional module applied to the final output.
    """

    def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
        super().__init__()
        self.attn_layers = nn.ModuleList(attn_layers)
        if conv_layers is None:
            self.conv_layers = None
        else:
            self.conv_layers = nn.ModuleList(conv_layers)
        self.norm = norm_layer

    def forward(self, x, attn_mask=None):
        """Run ``x`` through the stack.

        Returns
        -------
        tuple
            ``(output, attentions)`` where ``attentions`` is a list holding the
            second return value of every attention-layer call, in call order.
        """
        collected = []
        hidden = x

        if self.conv_layers is None:
            # Plain stack: every attention layer, one after another.
            for block in self.attn_layers:
                hidden, weights = block(hidden, attn_mask=attn_mask)
                collected.append(weights)
        else:
            # Paired pass: attention layer followed by its conv layer.
            for block, conv in zip(self.attn_layers, self.conv_layers):
                hidden, weights = block(hidden, attn_mask=attn_mask)
                collected.append(weights)
                hidden = conv(hidden)
            # The final attention layer is always applied after the paired pass.
            final_block = self.attn_layers[-1]
            hidden, weights = final_block(hidden, attn_mask=attn_mask)
            collected.append(weights)

        if self.norm is None:
            return hidden, collected
        return self.norm(hidden), collected
class InformerDecoder(nn.Module):
    """Sequential stack of decoder layers that also accumulates a trend term.

    Parameters
    ----------
    layers :
        Iterable of modules. Each is called as
        ``layer(x, cross, x_mask=..., cross_mask=...)`` and must return a
        ``(output, residual_trend)`` pair.
    norm_layer :
        Optional module applied to the output after all layers.
    projection :
        Optional module applied to the output after ``norm_layer``.
    """

    def __init__(self, layers, norm_layer=None, projection=None):
        super().__init__()
        self.layers = nn.ModuleList(layers)
        self.norm = norm_layer
        self.projection = projection

    def forward(self, x, cross, x_mask=None, cross_mask=None, trend=None):
        """Run ``x`` through every layer, summing the per-layer residual trends.

        Parameters
        ----------
        x :
            Decoder input, forwarded to each layer in turn.
        cross :
            Second positional input handed unchanged to every layer.
        x_mask, cross_mask :
            Optional masks forwarded to every layer as keyword arguments.
        trend :
            Optional initial trend. When omitted, accumulation starts from the
            first layer's residual trend.

        Returns
        -------
        tuple
            ``(output, trend)``. ``trend`` is the initial trend plus every
            layer's residual trend; it stays ``None`` only when no initial
            trend was given and the decoder has no layers.
        """
        for layer in self.layers:
            x, residual_trend = layer(x, cross, x_mask=x_mask, cross_mask=cross_mask)
            # `trend` defaults to None and `None + tensor` raises TypeError,
            # so seed the accumulator from the first residual in that case.
            trend = residual_trend if trend is None else trend + residual_trend

        if self.norm is not None:
            x = self.norm(x)

        if self.projection is not None:
            x = self.projection(x)

        return x, trend