"""
The backbone of TimesNet (``pypots.nn.modules.timesnet.backbone``): a stack of
``TimesBlock`` layers, each followed by layer normalization.
"""
# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause
import torch
import torch.nn as nn
from .layers import TimesBlock
class BackboneTimesNet(nn.Module):
    """Stack of ``TimesBlock`` layers, each followed by a shared ``LayerNorm``.

    Parameters
    ----------
    n_layers :
        Number of ``TimesBlock`` layers to stack.

    n_steps :
        Forwarded to every ``TimesBlock``; also stored as ``self.seq_len``.

    n_pred_steps :
        Forwarded to every ``TimesBlock``; also stored as ``self.n_pred_steps``.

    top_k :
        Forwarded to every ``TimesBlock`` (its meaning is defined there).

    d_model :
        Forwarded to every ``TimesBlock`` and used as the normalized size of the
        ``LayerNorm``, so each block's output must have a last dimension of
        size ``d_model``.

    d_ffn :
        Forwarded to every ``TimesBlock``.

    n_kernels :
        Forwarded to every ``TimesBlock``.
    """

    def __init__(
        self,
        n_layers,
        n_steps,
        n_pred_steps,
        top_k,
        d_model,
        d_ffn,
        n_kernels,
    ):
        super().__init__()
        self.seq_len = n_steps
        self.n_layers = n_layers
        self.n_pred_steps = n_pred_steps

        # A fresh TimesBlock is constructed per layer, so layers do not share weights.
        self.model = nn.ModuleList(
            [TimesBlock(n_steps, n_pred_steps, top_k, d_model, d_ffn, n_kernels) for _ in range(n_layers)]
        )
        # A single LayerNorm instance: its affine parameters are shared by all layers.
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Run ``X`` through every block in order, normalizing after each one.

        Parameters
        ----------
        X :
            Input tensor fed to the first ``TimesBlock``. Its layout is whatever
            ``TimesBlock`` expects — presumably (batch, n_steps + n_pred_steps,
            d_model); TODO confirm against ``TimesBlock``.

        Returns
        -------
        torch.Tensor
            ``layer_norm(block_k(... layer_norm(block_1(X)) ...))``. When the
            stack is empty (``n_layers == 0``), ``X`` is returned unchanged.
        """
        enc_out = X
        # Iterate the ModuleList directly rather than indexing by self.n_layers,
        # so the loop always matches the blocks that actually exist.
        for block in self.model:
            enc_out = self.layer_norm(block(enc_out))
        return enc_out