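Source code for pypots.nn.modules.transformer.layers

"""
Layer-level building blocks of the Transformer: the position-wise feed-forward
network, the encoder layer, and the decoder layer.
"""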

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause

from typing import Tuple, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from .attention import MultiHeadAttention, AttentionOperator


class PositionWiseFeedForward(nn.Module):
    """Position-wise feed-forward network (FFN) sub-layer of the Transformer.

    The input is expanded by a linear layer to ``d_hid`` features, passed
    through ReLU, projected back to ``d_in`` features, regularized with
    dropout, added to the original input (residual connection) and finally
    layer-normalized.

    Parameters
    ----------
    d_in:
        The dimension of the input tensor.

    d_hid:
        The dimension of the hidden layer.

    dropout:
        The dropout rate.

    """

    def __init__(self, d_in: int, d_hid: int, dropout: float = 0.1):
        super().__init__()
        # expansion and contraction projections
        self.linear_1 = nn.Linear(d_in, d_hid)
        self.linear_2 = nn.Linear(d_hid, d_in)
        # normalization applied after the residual sum
        self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the feed-forward sub-layer on ``x``.

        Parameters
        ----------
        x:
            Input tensor.

        Returns
        -------
        x:
            Output tensor, computed as
            ``layer_norm(dropout(linear_2(relu(linear_1(x)))) + x)``.

        """
        # expand to the hidden size and apply the ReLU non-linearity
        hidden = F.relu(self.linear_1(x))
        # project back to the input size, then apply dropout
        projected = self.dropout(self.linear_2(hidden))
        # residual connection: accumulated in place into the freshly computed
        # projection, so the caller's input tensor is never modified
        projected.add_(x)
        # normalize the summed result
        return self.layer_norm(projected)
class TransformerEncoderLayer(nn.Module):
    """A single Transformer encoder layer.

    The layer runs multi-head self-attention over its input, applies dropout,
    adds the input back (residual connection), layer-normalizes the sum and
    feeds the result to a :class:`PositionWiseFeedForward` module.

    Parameters
    ----------
    attn_opt:
        The attention operator for the multi-head attention module in the encoder layer.

    d_model:
        The dimension of the input tensor.

    n_heads:
        The number of heads in multi-head attention.

    d_k:
        The dimension of the key and query tensor.

    d_v:
        The dimension of the value tensor.

    d_ffn:
        The dimension of the hidden layer.

    dropout:
        The dropout rate. Used both for the dropout on the attention output
        and inside the feed-forward module.

    """

    def __init__(
        self,
        attn_opt: AttentionOperator,
        d_model: int,
        n_heads: int,
        d_k: int,
        d_v: int,
        d_ffn: int,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.slf_attn = MultiHeadAttention(attn_opt, d_model, n_heads, d_k, d_v)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
        self.pos_ffn = PositionWiseFeedForward(d_model, d_ffn, dropout)

    def forward(
        self,
        enc_input: torch.Tensor,
        src_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the encoder layer on ``enc_input``.

        Parameters
        ----------
        enc_input:
            Input tensor. It serves as query, key and value of the self-attention.

        src_mask:
            Masking tensor for the attention map. The shape should be
            [batch_size, n_heads, n_steps, n_steps].

        **kwargs:
            Extra keyword arguments forwarded unchanged to the self-attention module.

        Returns
        -------
        enc_output:
            Output tensor.

        attn_weights:
            The attention map.

        """
        # self-attention: the layer input is used as query, key and value
        attended, attn_weights = self.slf_attn(
            enc_input,
            enc_input,
            enc_input,
            attn_mask=src_mask,
            **kwargs,
        )

        # dropout, then the residual connection accumulated in place
        attended = self.dropout(attended)
        attended.add_(enc_input)

        # layer-norm followed by the position-wise feed-forward network
        normalized = self.layer_norm(attended)
        enc_output = self.pos_ffn(normalized)
        return enc_output, attn_weights
class TransformerDecoderLayer(nn.Module):
    """A single Transformer decoder layer.

    The layer chains three modules: multi-head self-attention over the decoder
    input, multi-head attention that queries the encoder output with the
    self-attention result, and a :class:`PositionWiseFeedForward` module.

    Parameters
    ----------
    slf_attn_opt:
        The attention operator for the multi-head attention module in the decoder layer.

    enc_attn_opt:
        The attention operator for the encoding multi-head attention module in the decoder layer.

    d_model:
        The dimension of the input tensor.

    n_heads:
        The number of heads in multi-head attention.

    d_k:
        The dimension of the key and query tensor.

    d_v:
        The dimension of the value tensor.

    d_ffn:
        The dimension of the hidden layer.

    dropout:
        The dropout rate. In this layer it is only passed on to the
        feed-forward module.

    """

    # NOTE(review): unlike TransformerEncoderLayer, this layer itself applies no
    # dropout, residual connection or layer-norm to the outputs of its two
    # attention modules -- confirm this is intended.

    def __init__(
        self,
        slf_attn_opt: AttentionOperator,
        enc_attn_opt: AttentionOperator,
        d_model: int,
        n_heads: int,
        d_k: int,
        d_v: int,
        d_ffn: int,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.slf_attn = MultiHeadAttention(slf_attn_opt, d_model, n_heads, d_k, d_v)
        self.enc_attn = MultiHeadAttention(enc_attn_opt, d_model, n_heads, d_k, d_v)
        self.pos_ffn = PositionWiseFeedForward(d_model, d_ffn, dropout)

    def forward(
        self,
        dec_input: torch.Tensor,
        enc_output: torch.Tensor,
        slf_attn_mask: Optional[torch.Tensor] = None,
        dec_enc_attn_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run the decoder layer.

        Parameters
        ----------
        dec_input:
            Input tensor. It serves as query, key and value of the self-attention.

        enc_output:
            Output tensor from the encoder. It serves as key and value of the
            encoding attention.

        slf_attn_mask:
            Masking tensor for the self-attention module.
            The shape should be [batch_size, n_heads, n_steps, n_steps].

        dec_enc_attn_mask:
            Masking tensor for the encoding attention module.
            The shape should be [batch_size, n_heads, n_steps, n_steps].

        **kwargs:
            Extra keyword arguments forwarded unchanged to both attention modules.

        Returns
        -------
        dec_output:
            Output tensor.

        dec_slf_attn:
            The self-attention map.

        dec_enc_attn:
            The encoding attention map.

        """
        # step 1: self-attention over the decoder input
        self_attended, dec_slf_attn = self.slf_attn(
            dec_input,
            dec_input,
            dec_input,
            attn_mask=slf_attn_mask,
            **kwargs,
        )

        # step 2: the self-attention result queries the encoder output
        cross_attended, dec_enc_attn = self.enc_attn(
            self_attended,
            enc_output,
            enc_output,
            attn_mask=dec_enc_attn_mask,
            **kwargs,
        )

        # step 3: position-wise feed-forward network
        dec_output = self.pos_ffn(cross_attended)
        return dec_output, dec_slf_attn, dec_enc_attn