# Source code for pypots.nn.modules.informer.layers

""" """

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause

from math import sqrt
from typing import Optional

import numpy as np
import torch
import torch.fft
import torch.nn as nn
import torch.nn.functional as F

from ....nn.modules.transformer.attention import AttentionOperator


class ProbMask:
    """Causal mask gathered for the queries selected by ProbSparse attention.

    Builds the strict upper-triangular mask of shape ``[L, scores.shape[-1]]``
    (``True`` for key positions after the query position). For every batch
    element and head, it then keeps only the rows listed in ``index``. The
    resulting boolean mask is reshaped to ``scores.shape``.

    Parameters
    ----------
    B : int
        Batch size.
    H : int
        Number of attention heads.
    L : int
        Number of query positions, i.e. rows of the full causal mask.
    index : torch.Tensor
        Integer indices of the selected query rows. The gathered result must
        have as many elements as ``scores``. ProbAttention passes a
        ``[B, H, n_top]`` tensor here.
    scores : torch.Tensor
        Scores of the selected queries. Only its last dimension (number of
        keys) and its shape are used.
    device : str or torch.device
        Device on which the mask is created.
    """

    def __init__(self, B, H, L, index, scores, device="cpu"):
        n_keys = scores.shape[-1]
        # Allocate on the target device directly rather than building on the
        # CPU and copying the tensor over afterwards.
        causal = torch.ones(L, n_keys, dtype=torch.bool, device=device).triu(1)
        causal_ex = causal[None, None, :].expand(B, H, L, n_keys)
        batch_idx = torch.arange(B, device=device)[:, None, None]
        head_idx = torch.arange(H, device=device)[None, :, None]
        # The gathered tensor lives on `device` already, because it indexes `causal_ex`.
        indicator = causal_ex[batch_idx, head_idx, index.to(device), :]
        self._mask = indicator.view(scores.shape)

    @property
    def mask(self):
        """Boolean mask with the same shape as ``scores``; ``True`` marks entries to mask out."""
        return self._mask


class ConvLayer(nn.Module):
    """Distilling layer placed between Informer encoder layers.

    Applies a circular 1-D convolution (kernel size 3) over the time axis,
    then batch normalization, ELU and a stride-2 max-pool. The channel count
    stays ``c_in`` while the sequence length is roughly halved.

    Parameters
    ----------
    c_in : int
        Number of input and output channels (the feature dimension).
    """

    def __init__(self, c_in):
        super().__init__()
        import re

        # Compare the PyTorch version numerically. As plain strings, "1.10.0"
        # sorts below "1.5.0", so string comparison is unreliable here.
        # PyTorch < 1.5 uses padding=2; from 1.5 on, padding=1. With
        # kernel_size=3 and stride 1, padding=1 keeps the sequence length unchanged.
        torch_major_minor = tuple(int(part) for part in re.findall(r"\d+", str(torch.__version__))[:2])
        padding = 1 if torch_major_minor >= (1, 5) else 2
        self.downConv = nn.Conv1d(
            in_channels=c_in,
            out_channels=c_in,
            kernel_size=3,
            padding=padding,
            padding_mode="circular",
        )
        self.norm = nn.BatchNorm1d(c_in)
        self.activation = nn.ELU()
        self.maxPool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        """Distill ``x`` along the time axis.

        Parameters
        ----------
        x : torch.Tensor
            Input of shape ``[batch_size, n_steps, c_in]``.

        Returns
        -------
        torch.Tensor
            Shape ``[batch_size, (n_steps - 1) // 2 + 1, c_in]`` when the
            convolution uses ``padding=1``.
        """
        # Conv1d/BatchNorm1d/MaxPool1d expect channels in dim 1.
        x = self.downConv(x.permute(0, 2, 1))
        x = self.norm(x)
        x = self.activation(x)
        x = self.maxPool(x)
        # Back to [batch_size, n_steps', c_in].
        x = x.transpose(1, 2)
        return x
class ProbAttention(AttentionOperator):
    """ProbSparse attention operator from Informer.

    Only ``u = factor * ceil(ln(L_Q))`` queries get exact attention scores.
    These are the queries with the largest sparsity measurement, estimated on
    ``factor * ceil(ln(L_K))`` randomly sampled keys per query. Every other
    query keeps a cheap initial context: the mean of ``V`` when unmasked, or
    the cumulative sum of ``V`` when masked.

    Parameters
    ----------
    mask_flag : bool
        Whether to apply a causal mask to the selected queries. Masked mode
        supports self-attention only (``L_Q == L_V``).
    factor : int
        Sampling factor ``c``. It controls how many keys are sampled and how
        many queries are selected.
    attention_dropout : float
        Probability for ``self.dropout``. NOTE(review): this dropout module is
        never applied by this operator.
    scale : float, optional
        Multiplier for the attention scores; ``None`` means ``1 / sqrt(D)``.
    """

    def __init__(
        self,
        mask_flag=True,
        factor=5,
        attention_dropout=0.1,
        scale=None,
    ):
        super().__init__()
        self.factor = factor
        self.scale = scale
        self.mask_flag = mask_flag
        self.dropout = nn.Dropout(attention_dropout)

    def _n_samples(self, length):
        """Return ``factor * ceil(ln(length))`` clipped to the range ``[1, length]``.

        The lower bound keeps length-1 sequences working. There ``ln(1) == 0``
        would request zero samples and zero queries, and the max/topk
        reductions over an empty dimension would raise.
        """
        count = self.factor * int(np.ceil(np.log(length)))
        return max(1, min(count, length))

    def _prob_QK(self, Q, K, sample_k, n_top):
        """Select the ``n_top`` most "active" queries and score them against all keys.

        Parameters
        ----------
        Q, K : torch.Tensor
            Queries ``[B, H, L_Q, D]`` and keys ``[B, H, L_K, D]``.
        sample_k : int
            Number of keys sampled per query to estimate the sparsity measurement.
        n_top : int
            Number of queries to keep.

        Returns
        -------
        tuple of torch.Tensor
            ``Q_K`` with shape ``[B, H, n_top, L_K]`` and the selected query
            indices ``M_top`` with shape ``[B, H, n_top]``.
        """
        B, H, L_K, E = K.shape
        _, _, L_Q, _ = Q.shape

        # Sample `sample_k` keys for every query. The sample indices are drawn
        # on the CPU (default generator) and shared across batch and heads.
        K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E)
        index_sample = torch.randint(L_K, (L_Q, sample_k))
        K_sample = K_expand[:, :, torch.arange(L_Q).unsqueeze(1), index_sample, :]
        Q_K_sample = torch.matmul(Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze(-2)

        # Sparsity measurement: max of the sampled scores minus their sum / L_K.
        M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K)
        M_top = M.topk(n_top, sorted=False)[1]

        # Full scores, but only for the selected queries.
        Q_reduce = Q[torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], M_top, :]
        Q_K = torch.matmul(Q_reduce, K.transpose(-2, -1))
        return Q_K, M_top

    def _get_initial_context(self, V, L_Q):
        """Build the context that non-selected queries keep.

        Unmasked: every query position gets the mean of ``V`` over the
        sequence axis. Masked: position ``i`` gets the cumulative sum of
        ``V`` up to ``i``.
        """
        B, H, L_V, D = V.shape
        if not self.mask_flag:
            V_mean = V.mean(dim=-2)
            context = V_mean.unsqueeze(-2).expand(B, H, L_Q, V_mean.shape[-1]).clone()
        else:
            # The cumulative (causal) context is only defined for self-attention.
            assert L_Q == L_V, "masked ProbAttention requires L_Q == L_V (self-attention only)"
            context = V.cumsum(dim=-2)
        return context

    def _update_context(self, context_in, V, scores, index, L_Q, attn_mask):
        """Overwrite the selected queries' context rows with exact attention output.

        Returns the updated context ``[B, H, L_Q, D]`` and an attention map
        ``[B, H, L_Q, L_V]``. In that map, non-selected queries carry uniform
        weights ``1 / L_V``.

        The ``attn_mask`` argument is not used. When ``mask_flag`` is set, a
        causal ProbMask is built here instead. ``scores`` is modified in place
        in that case.
        """
        B, H, L_V, D = V.shape

        if self.mask_flag:
            prob_mask = ProbMask(B, H, L_Q, index, scores, device=V.device)
            scores.masked_fill_(prob_mask.mask, -np.inf)

        attn = torch.softmax(scores, dim=-1)

        batch_idx = torch.arange(B)[:, None, None]
        head_idx = torch.arange(H)[None, :, None]
        context_in[batch_idx, head_idx, index, :] = torch.matmul(attn, V).type_as(context_in)

        # One row per query (L_Q), one column per key (L_V). Sizing the rows
        # by L_V would break cross-attention whenever L_Q != L_V.
        attns = torch.full((B, H, L_Q, L_V), 1.0 / L_V, dtype=attn.dtype, device=attn.device)
        attns[batch_idx, head_idx, index, :] = attn
        return context_in, attns

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        """Run ProbSparse attention.

        Parameters
        ----------
        q, k, v : torch.Tensor
            Shape ``[batch_size, n_steps, n_heads, d_tensor]``.
        attn_mask : torch.Tensor, optional
            Accepted for interface compatibility; not used by this operator.
        **kwargs
            Extra keyword arguments (e.g. ``tau``/``delta``) are accepted and ignored.

        Returns
        -------
        tuple of torch.Tensor
            Context ``[batch_size, L_Q, n_heads, d_v]`` and attention map
            ``[batch_size, n_heads, L_Q, L_V]``.
        """
        B, L_Q, H, D = q.shape
        _, L_K, _, _ = k.shape

        # Work in [B, H, L, D] layout.
        q = q.transpose(2, 1)
        k = k.transpose(2, 1)
        v = v.transpose(2, 1)

        U_part = self._n_samples(L_K)  # keys sampled per query: c*ln(L_k)
        u = self._n_samples(L_Q)  # queries given exact attention: c*ln(L_q)

        scores_top, index = self._prob_QK(q, k, sample_k=U_part, n_top=u)

        # Use `is not None` so that any explicitly configured scale is honoured.
        scale = self.scale if self.scale is not None else 1.0 / sqrt(D)
        scores_top = scores_top * scale

        context = self._get_initial_context(v, L_Q)
        context, attn = self._update_context(context, v, scores_top, index, L_Q, attn_mask)
        return context.transpose(2, 1).contiguous(), attn
class InformerEncoderLayer(nn.Module):
    """A single Informer encoder layer.

    It consists of an attention sub-layer and a point-wise feed-forward
    sub-layer. Each sub-layer adds a residual connection and is followed by
    LayerNorm. The feed-forward network is two kernel-size-1 convolutions,
    ``d_model -> d_ff -> d_model``.

    Parameters
    ----------
    attention :
        Callable invoked as ``attention(x, x, x, attn_mask=...)``. It must
        return a pair ``(output, attention_weights)``.
    d_model : int
        Feature dimension of the input and output.
    d_ff : int, optional
        Hidden size of the feed-forward network. A falsy value selects ``4 * d_model``.
    dropout : float
        Dropout probability. It is applied after attention and inside the
        feed-forward network.
    activation : str
        ``"relu"`` selects ReLU; any other value selects GELU.
    """

    def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"):
        super().__init__()
        hidden_dim = d_ff if d_ff else 4 * d_model

        self.attention = attention
        self.conv1 = nn.Conv1d(d_model, hidden_dim, kernel_size=1)
        self.conv2 = nn.Conv1d(hidden_dim, d_model, kernel_size=1)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        if activation == "relu":
            self.activation = F.relu
        else:
            self.activation = F.gelu

    def _feed_forward(self, h):
        """Apply the convolutional feed-forward network to ``h`` (features on the last axis)."""
        # Conv1d wants channels in dim 1: swap the feature axis in, then back out.
        out = self.conv1(h.transpose(-1, 1))
        out = self.dropout(self.activation(out))
        out = self.conv2(out).transpose(-1, 1)
        return self.dropout(out)

    def forward(self, x, attn_mask=None):
        """Encode ``x``, shaped ``[batch_size, n_steps, d_model]``.

        Returns the encoded tensor (same shape as ``x``) together with
        whatever attention weights ``self.attention`` produced.
        """
        attended, attn = self.attention(x, x, x, attn_mask=attn_mask)
        hidden = self.norm1(x + self.dropout(attended))
        return self.norm2(hidden + self._feed_forward(hidden)), attn
class InformerDecoderLayer(nn.Module):
    """A single Informer decoder layer.

    It runs three sub-layers: self-attention, cross-attention over the
    encoder output, and a point-wise feed-forward network. Each sub-layer
    adds a residual connection and is followed by LayerNorm. The feed-forward
    network is two kernel-size-1 convolutions, ``d_model -> d_ff -> d_model``.

    Parameters
    ----------
    self_attention, cross_attention :
        Callables invoked as ``attn(q, k, v, attn_mask=..., tau=..., delta=...)``.
        Only the first element of their return value is used.
    d_model : int
        Feature dimension of the input and output.
    d_ff : int, optional
        Hidden size of the feed-forward network. A falsy value selects ``4 * d_model``.
    dropout : float
        Dropout probability. It is applied after each attention and inside
        the feed-forward network.
    activation : str
        ``"relu"`` selects ReLU; any other value selects GELU.
    """

    def __init__(
        self,
        self_attention,
        cross_attention,
        d_model,
        d_ff=None,
        dropout=0.1,
        activation="relu",
    ):
        super().__init__()
        hidden_dim = d_ff if d_ff else 4 * d_model

        self.self_attention = self_attention
        self.cross_attention = cross_attention
        self.conv1 = nn.Conv1d(d_model, hidden_dim, kernel_size=1)
        self.conv2 = nn.Conv1d(hidden_dim, d_model, kernel_size=1)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        if activation == "relu":
            self.activation = F.relu
        else:
            self.activation = F.gelu

    def _feed_forward(self, h):
        """Apply the convolutional feed-forward network to ``h`` (features on the last axis)."""
        # Conv1d wants channels in dim 1: swap the feature axis in, then back out.
        out = self.conv1(h.transpose(-1, 1))
        out = self.dropout(self.activation(out))
        out = self.conv2(out).transpose(-1, 1)
        return self.dropout(out)

    def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None):
        """Decode ``x`` while attending to the encoder output ``cross``.

        ``tau`` is forwarded to both attentions. ``delta`` is forwarded only
        to the cross-attention; the self-attention always receives ``delta=None``.
        Returns a tensor shaped like ``x``; attention weights are discarded.
        """
        self_out = self.self_attention(x, x, x, attn_mask=x_mask, tau=tau, delta=None)[0]
        hidden = self.norm1(x + self.dropout(self_out))

        cross_out = self.cross_attention(hidden, cross, cross, attn_mask=cross_mask, tau=tau, delta=delta)[0]
        hidden = self.norm2(hidden + self.dropout(cross_out))

        return self.norm3(hidden + self._feed_forward(hidden))