Source code for pypots.nn.modules.autoformer.layers

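"""Layers of Autoformer: the auto-correlation attention operator, the seasonal layer normalization, the moving-average series decomposition, and the encoder/decoder layers."""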

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause

import math
from typing import Tuple, Optional

import torch
import torch.fft
import torch.nn as nn
import torch.nn.functional as F

from ..transformer.attention import AttentionOperator, MultiHeadAttention


class AutoCorrelation(AttentionOperator):
    """
    AutoCorrelation Mechanism with the following two phases:
        (1) period-based dependencies discovery
        (2) time delay aggregation

    This block can replace the self-attention family mechanism seamlessly.

    Parameters
    ----------
    factor :
        Scales the number of aggregated delays, which is
        ``int(factor * ln(length))`` clamped to ``[1, number of lags]``.

    attention_dropout :
        Dropout probability of ``self.dropout``. The layer is created here but
        is not applied by any method of this class.
    """

    def __init__(
        self,
        factor=1,
        attention_dropout=0.1,
    ):
        super().__init__()
        self.factor = factor
        self.dropout = nn.Dropout(attention_dropout)

    def _top_k(self, length: int, n_lags: int) -> int:
        """Return how many delays to aggregate.

        ``int(factor * ln(length))`` is 0 for very short sequences and may exceed
        the number of available lags for a large ``factor``; both cases would
        break ``torch.topk``/``torch.stack``, so the value is clamped to
        ``[1, n_lags]``.
        """
        return max(1, min(n_lags, int(self.factor * math.log(length))))

    def time_delay_agg_training(self, values, corr):
        """
        SpeedUp version of Autocorrelation (a batch-normalization style design)
        This is for the training phase.

        The top-k delays are selected once from the correlation averaged over
        the batch, heads and channels, so every sample is rolled by the same
        delays.

        Parameters
        ----------
        values :
            4-D tensor whose last dimension is the time axis
            (``forward`` passes [batch, head, channel, length]).

        corr :
            Auto-correlation tensor with the same layout as ``values``.

        Returns
        -------
        delays_agg :
            Float tensor shaped like ``values``: the softmax-weighted sum of the
            rolled copies of ``values``.
        """
        length = values.shape[3]
        # find top k
        top_k = self._top_k(length, corr.shape[-1])
        mean_value = torch.mean(torch.mean(corr, dim=1), dim=1)  # [batch, n_lags]
        index = torch.topk(torch.mean(mean_value, dim=0), top_k, dim=-1)[1]
        weights = mean_value[:, index]  # [batch, top_k]
        # update corr
        tmp_corr = torch.softmax(weights, dim=-1)
        # aggregation
        delays_agg = torch.zeros_like(values).float()
        # convert the delays to Python ints once instead of once per iteration
        for i, shift in enumerate(index.tolist()):
            pattern = torch.roll(values, -shift, -1)
            # the per-sample weight broadcasts over head, channel and time
            delays_agg = delays_agg + pattern * tmp_corr[:, i, None, None, None]
        return delays_agg

    def time_delay_agg_inference(self, values, corr):
        """
        SpeedUp version of Autocorrelation (a batch-normalization style design)
        This is for the inference phase.

        Unlike the training version, the top-k delays are selected per sample
        (from the correlation averaged over heads and channels), and the shifted
        series are read with ``torch.gather`` from ``values`` repeated twice
        along time, which emulates a circular shift.

        Parameters
        ----------
        values :
            4-D tensor whose last dimension is the time axis
            (``forward`` passes [batch, head, channel, length]).

        corr :
            Auto-correlation tensor with the same layout as ``values``.

        Returns
        -------
        delays_agg :
            Float tensor shaped like ``values``.
        """
        batch, head, channel, length = values.shape
        # index init, created directly on the target device
        init_index = torch.arange(length, device=values.device).view(1, 1, 1, length)
        # find top k
        top_k = self._top_k(length, corr.shape[-1])
        mean_value = torch.mean(torch.mean(corr, dim=1), dim=1)  # [batch, n_lags]
        weights, delay = torch.topk(mean_value, top_k, dim=-1)
        # update corr
        tmp_corr = torch.softmax(weights, dim=-1)
        # aggregation
        tmp_values = values.repeat(1, 1, 1, 2)
        delays_agg = torch.zeros_like(values).float()
        for i in range(top_k):
            # gather does not broadcast its index, so expand it to the full shape
            tmp_delay = (init_index + delay[:, i, None, None, None]).expand(batch, head, channel, length)
            pattern = torch.gather(tmp_values, dim=-1, index=tmp_delay)
            delays_agg = delays_agg + pattern * tmp_corr[:, i, None, None, None]
        return delays_agg

    def time_delay_agg_full(self, values, corr):
        """
        Standard version of Autocorrelation

        The top-k delays are selected independently for every leading position
        of ``corr`` (no averaging over heads or channels).

        Parameters
        ----------
        values :
            4-D tensor whose last dimension is the time axis.

        corr :
            Auto-correlation tensor with the same layout as ``values``.

        Returns
        -------
        delays_agg :
            Float tensor shaped like ``values``.
        """
        length = values.shape[3]
        # index init; broadcasts against the per-position delays below
        init_index = torch.arange(length, device=values.device).view(1, 1, 1, length)
        # find top k
        top_k = self._top_k(length, corr.shape[-1])
        weights, delay = torch.topk(corr, top_k, dim=-1)
        # update corr
        tmp_corr = torch.softmax(weights, dim=-1)
        # aggregation
        tmp_values = values.repeat(1, 1, 1, 2)
        delays_agg = torch.zeros_like(values).float()
        for i in range(top_k):
            tmp_delay = init_index + delay[..., i].unsqueeze(-1)
            pattern = torch.gather(tmp_values, dim=-1, index=tmp_delay)
            delays_agg = delays_agg + pattern * (tmp_corr[..., i].unsqueeze(-1))
        return delays_agg

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute the auto-correlation between ``q`` and ``k`` and aggregate ``v`` over the top delays.

        Parameters
        ----------
        q, k, v :
            4-D tensors [batch_size, n_steps, n_heads, d_tensor];
            d_tensor could be d_q, d_k, d_v.

        attn_mask :
            Accepted for interface compatibility; not used by this operator.

        Returns
        -------
        output :
            Aggregated values, [batch_size, n_steps of q, n_heads, d_v].

        attn :
            The auto-correlation, [batch_size, n_steps of q, n_heads, d_k].
        """
        # q, k, v all have 4 dimensions [batch_size, n_steps, n_heads, d_tensor]
        # d_tensor could be d_q, d_k, d_v
        B, L, H, E = q.shape
        _, S, _, D = v.shape

        if L > S:
            # zero-pad k and v along time up to the length of q; each padding uses
            # the feature size of the tensor it extends, so d_k != d_v is supported
            k_pad = torch.zeros((B, L - S, H, k.shape[-1]), dtype=torch.float, device=q.device)
            v_pad = torch.zeros((B, L - S, H, D), dtype=torch.float, device=q.device)
            v = torch.cat([v, v_pad], dim=1)
            k = torch.cat([k, k_pad], dim=1)
        else:
            v = v[:, :L, :, :]
            k = k[:, :L, :, :]

        # period-based dependencies
        q_fft = torch.fft.rfft(q.permute(0, 2, 3, 1).contiguous(), dim=-1)
        k_fft = torch.fft.rfft(k.permute(0, 2, 3, 1).contiguous(), dim=-1)
        res = q_fft * torch.conj(k_fft)
        # n=L is required: without it irfft returns an even-length signal, i.e. L-1 steps for odd L
        corr = torch.fft.irfft(res, n=L, dim=-1)

        # time delay agg
        values = v.permute(0, 2, 3, 1).contiguous()
        if self.training:
            V = self.time_delay_agg_training(values, corr).permute(0, 3, 1, 2)
        else:
            V = self.time_delay_agg_inference(values, corr).permute(0, 3, 1, 2)

        attn = corr.permute(0, 3, 1, 2)
        output = V.contiguous()
        return output, attn
class SeasonalLayerNorm(nn.Module):
    """A special designed layer normalization for the seasonal part.

    The input is layer-normalized over its last dimension, then the mean of the
    normalized tensor along dimension 1 is subtracted from it.
    """

    def __init__(self, n_channels):
        super().__init__()
        self.layer_norm = nn.LayerNorm(n_channels)

    def forward(self, x):
        """Return ``layer_norm(x)`` with its mean along dimension 1 removed."""
        normalized = self.layer_norm(x)
        # keepdim leaves a size-1 axis that broadcasts across dimension 1
        step_mean = normalized.mean(dim=1, keepdim=True)
        return normalized - step_mean
class MovingAvgBlock(nn.Module):
    """
    The moving average block to highlight the trend of time series.

    Parameters
    ----------
    kernel_size :
        Window size of the moving average.

    stride :
        Stride of the underlying ``nn.AvgPool1d``.
    """

    def __init__(self, kernel_size, stride):
        super().__init__()
        self.kernel_size = kernel_size
        self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)

    def forward(self, x):
        """Average ``x`` along dimension 1 with a sliding window.

        ``x`` is indexed as a 3-D tensor whose dimension 1 is the time axis. It is
        padded by repeating its first and last steps so that ``kernel_size - 1``
        steps are added in total; with ``stride=1`` the output therefore has as
        many steps as the input, for odd and even ``kernel_size`` alike.
        """
        # padding on the both ends of time series
        n_front = (self.kernel_size - 1) // 2
        # the end takes the remainder, i.e. one extra step when kernel_size is even;
        # symmetric padding would make the output one step shorter in that case
        n_end = self.kernel_size - 1 - n_front
        front = x[:, 0:1, :].repeat(1, n_front, 1)
        end = x[:, -1:, :].repeat(1, n_end, 1)
        x = torch.cat([front, x, end], dim=1)
        # AvgPool1d pools over the last dimension, so move time there and back
        x = self.avg(x.permute(0, 2, 1))
        x = x.permute(0, 2, 1)
        return x
class SeriesDecompositionBlock(nn.Module):
    """
    Series decomposition block

    Splits the input into its moving average and the residual left after
    subtracting that moving average.
    """

    def __init__(self, kernel_size):
        super().__init__()
        self.moving_avg = MovingAvgBlock(kernel_size, stride=1)

    def forward(self, x):
        """Return ``(x - moving_average(x), moving_average(x))``."""
        trend = self.moving_avg(x)
        return x - trend, trend
class AutoformerEncoderLayer(nn.Module):
    """Autoformer encoder layer with the progressive decomposition architecture.

    The layer applies multi-head attention and a two-layer point-wise
    (kernel size 1) convolutional feed-forward network, each followed by a
    residual connection and a series decomposition that keeps only the
    residual part.

    Parameters
    ----------
    attn_opt :
        The attention operator wrapped by the multi-head attention.

    d_model :
        Feature size of the input and output.

    n_heads :
        Number of attention heads; each head uses ``d_model // n_heads`` features.

    d_ffn :
        Hidden size of the feed-forward network; a falsy value selects ``4 * d_model``.

    moving_avg :
        Kernel size of the moving average in the decomposition blocks.

    dropout :
        Dropout probability.

    activation :
        ``"relu"`` selects ``F.relu``; any other value selects ``F.gelu``.
    """

    def __init__(
        self,
        attn_opt: AttentionOperator,
        d_model: int,
        n_heads: int,
        d_ffn: int,
        moving_avg: int = 25,
        dropout: float = 0.1,
        activation="relu",
    ):
        super().__init__()
        if not d_ffn:
            d_ffn = 4 * d_model
        d_head = d_model // n_heads

        self.attention = MultiHeadAttention(attn_opt, d_model, n_heads, d_head, d_head)
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ffn, kernel_size=1, bias=False)
        self.conv2 = nn.Conv1d(in_channels=d_ffn, out_channels=d_model, kernel_size=1, bias=False)
        self.series_decomp1 = SeriesDecompositionBlock(moving_avg)
        self.series_decomp2 = SeriesDecompositionBlock(moving_avg)
        self.dropout = nn.Dropout(dropout)
        if activation == "relu":
            self.activation = F.relu
        else:
            self.activation = F.gelu

    def forward(self, x, attn_mask=None):
        """Run the layer and return ``(output, attention)`` where ``attention`` comes from the attention module."""
        attn_out, attn = self.attention(x, x, x, attn_mask=attn_mask)
        # residual connection, then drop the moving-average part
        decomposed, _ = self.series_decomp1(x + self.dropout(attn_out))

        # feed-forward network; Conv1d expects the feature axis at dimension 1
        hidden = self.conv1(decomposed.transpose(-1, 1))
        hidden = self.dropout(self.activation(hidden))
        ffn_out = self.dropout(self.conv2(hidden).transpose(-1, 1))

        res, _ = self.series_decomp2(decomposed + ffn_out)
        return res, attn
class AutoformerDecoderLayer(nn.Module):
    """
    Autoformer decoder layer with the progressive decomposition architecture

    The layer applies self-attention, cross-attention and a two-layer point-wise
    (kernel size 1) convolutional feed-forward network. Each stage is followed by
    a residual connection and a series decomposition; the three moving-average
    parts are summed and projected to ``d_out`` channels by a circular
    convolution with kernel size 3.

    Parameters
    ----------
    self_attn_opt :
        The attention operator used for self-attention.

    cross_attn_opt :
        The attention operator used for cross-attention.

    d_model :
        Feature size of the input.

    n_heads :
        Number of attention heads; each head uses ``d_model // n_heads`` features.

    d_out :
        Number of output channels of the trend projection.

    d_ff :
        Hidden size of the feed-forward network; a falsy value selects ``4 * d_model``.

    moving_avg :
        Kernel size of the moving average in the decomposition blocks.

    dropout :
        Dropout probability.

    activation :
        ``"relu"`` selects ``F.relu``; any other value selects ``F.gelu``.
    """

    def __init__(
        self,
        self_attn_opt,
        cross_attn_opt,
        d_model,
        n_heads,
        d_out,
        d_ff=None,
        moving_avg=25,
        dropout=0.1,
        activation="relu",
    ):
        super().__init__()
        if not d_ff:
            d_ff = 4 * d_model
        d_head = d_model // n_heads

        self.self_attention = MultiHeadAttention(self_attn_opt, d_model, n_heads, d_head, d_head)
        self.cross_attention = MultiHeadAttention(cross_attn_opt, d_model, n_heads, d_head, d_head)
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1, bias=False)
        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1, bias=False)
        self.series_decomp1 = SeriesDecompositionBlock(moving_avg)
        self.series_decomp2 = SeriesDecompositionBlock(moving_avg)
        self.series_decomp3 = SeriesDecompositionBlock(moving_avg)
        self.dropout = nn.Dropout(dropout)
        self.projection = nn.Conv1d(
            in_channels=d_model,
            out_channels=d_out,
            kernel_size=3,
            stride=1,
            padding=1,
            padding_mode="circular",
            bias=False,
        )
        if activation == "relu":
            self.activation = F.relu
        else:
            self.activation = F.gelu

    def forward(self, x, cross, x_mask=None, cross_mask=None):
        """Run the layer and return ``(output, residual_trend)``.

        ``cross`` supplies the keys and values of the cross-attention.
        ``residual_trend`` is the projected sum of the three moving-average parts.
        """
        # self-attention stage
        self_attn_out = self.self_attention(x, x, x, attn_mask=x_mask)[0]
        x, trend1 = self.series_decomp1(x + self.dropout(self_attn_out))

        # cross-attention stage
        cross_attn_out = self.cross_attention(x, cross, cross, attn_mask=cross_mask)[0]
        x, trend2 = self.series_decomp2(x + self.dropout(cross_attn_out))

        # feed-forward stage; Conv1d expects the feature axis at dimension 1
        hidden = self.conv1(x.transpose(-1, 1))
        hidden = self.dropout(self.activation(hidden))
        ffn_out = self.dropout(self.conv2(hidden).transpose(-1, 1))
        x, trend3 = self.series_decomp3(x + ffn_out)

        trend_sum = trend1 + trend2 + trend3
        residual_trend = self.projection(trend_sum.permute(0, 2, 1)).transpose(1, 2)
        return x, residual_trend