Source code for pypots.nn.modules.timesnet.layers

"""Layers used by TimesNet: FFT-based period detection and the TimesBlock module."""

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause

import torch
import torch.fft
import torch.nn as nn
import torch.nn.functional as F

from ..inception import InceptionBlockV1


def FFT_for_Period(x, k=2):
    """Pick the ``k`` strongest frequency bins of ``x`` and turn them into periods.

    The real FFT is taken along dim 1. Amplitudes are averaged over dim 0 and
    the last dim to rank the bins; the zero-frequency bin is excluded from the
    ranking by setting its score to 0.

    Parameters
    ----------
    x :
        Input tensor. The in-code note gives its layout as ``[B, T, C]``.

    k :
        Number of frequency bins to select.

    Returns
    -------
    period :
        NumPy integer array of length ``k`` holding ``x.shape[1] // index`` for
        every selected bin index. A selected index of 0 is treated as 1, so the
        corresponding period is ``x.shape[1]`` and never 0.

    period_weight :
        Tensor holding, for each sample along dim 0, the amplitude (averaged
        over the last dim) at each selected bin.
    """
    # [B, T, C]
    xf = torch.fft.rfft(x, dim=1)
    # the amplitude spectrum is needed twice (ranking + weights): compute it once
    amplitude = xf.abs()
    # find period by amplitudes
    frequency_list = amplitude.mean(0).mean(-1)
    frequency_list[0] = 0
    _, top_list = torch.topk(frequency_list, k)
    top_list = top_list.detach().cpu().numpy()
    # Bin 0 can still be returned by topk (e.g. constant input, where every score
    # is 0). Dividing by it would give a period of 0, so clamp the divisor to 1.
    period = x.shape[1] // top_list.clip(min=1)
    return period, amplitude.mean(-1)[:, top_list]


class TimesBlock(nn.Module):
    """Fold a series into 2-D by its dominant periods, convolve, and merge.

    For each of the ``top_k`` periods returned by ``FFT_for_Period`` the input
    is zero-padded to a multiple of the period, reshaped so that one axis runs
    over periods and the other within a period, passed through ``self.conv``,
    and reshaped back. The per-period results are combined with softmax
    weights derived from the FFT amplitudes, and the input is added back as a
    residual.

    Parameters
    ----------
    seq_len, pred_len :
        Their sum is the series length the block operates on.
        NOTE(review): ``forward`` reshapes with this sum, so the input's dim 1
        is presumably expected to equal ``seq_len + pred_len`` - confirm
        against callers.

    top_k :
        Number of periods to use.

    d_model :
        Channel count entering the first and leaving the second inception block.

    d_ffn :
        Channel count between the two inception blocks.

    num_kernels :
        Forwarded to both ``InceptionBlockV1`` instances.
    """

    def __init__(self, seq_len, pred_len, top_k, d_model, d_ffn, num_kernels):
        super().__init__()
        self.seq_len = seq_len
        self.pred_len = pred_len
        self.top_k = top_k
        # parameter-efficient design
        self.conv = nn.Sequential(
            InceptionBlockV1(d_model, d_ffn, num_kernels=num_kernels),
            nn.GELU(),
            InceptionBlockV1(d_ffn, d_model, num_kernels=num_kernels),
        )

    def forward(self, x):
        """Apply the block to ``x`` (a 3-D tensor) and return a tensor of the same shape."""
        B, _, N = x.size()
        total_len = self.seq_len + self.pred_len
        period_list, period_weight = FFT_for_Period(x, self.top_k)

        res = []
        for i in range(self.top_k):
            period = period_list[i]
            # padding
            if total_len % period != 0:
                length = ((total_len // period) + 1) * period
                # new_zeros keeps the padding on x's device and in x's dtype
                padding = x.new_zeros([B, length - total_len, N])
                out = torch.cat([x, padding], dim=1)
            else:
                length = total_len
                out = x
            # reshape
            out = out.reshape(B, length // period, period, N).permute(0, 3, 1, 2).contiguous()
            # 2D conv: from 1d Variation to 2d Variation
            out = self.conv(out)
            # reshape back, dropping the padded tail
            out = out.permute(0, 2, 3, 1).reshape(B, -1, N)
            res.append(out[:, :total_len, :])
        res = torch.stack(res, dim=-1)

        # adaptive aggregation
        period_weight = F.softmax(period_weight, dim=1)
        # [B, k] -> [B, 1, 1, k]; broadcasting replaces an explicit repeat
        period_weight = period_weight.unsqueeze(1).unsqueeze(1)
        res = torch.sum(res * period_weight, -1)

        # residual connection
        res = res + x
        return res