Source code for pypots.nn.modules.transformer.embedding

"""
Embedding methods for Transformer models are put here.


This implementation is inspired by the official Informer implementation: https://github.com/zhouhaoyi/Informer2020/blob/main/models/embed.py
"""

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause

import math

import torch
import torch.fft
import torch.nn as nn


class PositionalEncoding(nn.Module):
    """The original (sinusoidal) positional-encoding module for Transformer.

    A table of shape ``[1, n_positions, d_hid]`` is precomputed once and stored as the
    non-trainable buffer ``pos_table``. Even feature columns hold sines and odd feature
    columns hold cosines of the position scaled by geometrically decreasing frequencies.

    Parameters
    ----------
    d_hid:
        The dimension of the hidden layer. Both even and odd values are supported.

    n_positions:
        The max number of positions.

    """

    def __init__(self, d_hid: int, n_positions: int = 1000):
        super().__init__()
        pe = torch.zeros(n_positions, d_hid, requires_grad=False).float()
        position = torch.arange(0, n_positions).float().unsqueeze(1)
        div_term = (torch.arange(0, d_hid, 2).float() * -(torch.log(torch.tensor(10000)) / d_hid)).exp()
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_hid there is one fewer cosine column than sine columns, so the
        # frequencies are trimmed to match; for an even d_hid the slice is a no-op.
        pe[:, 1::2] = torch.cos(position * div_term[: d_hid // 2])
        pe = pe.unsqueeze(0)
        self.register_buffer("pos_table", pe)

    def forward(
        self,
        x: torch.Tensor,
        dim: int = 1,
        return_only_pos: bool = False,
    ) -> torch.Tensor:
        """Forward processing of the positional encoding module.

        Parameters
        ----------
        x:
            Input tensor.

        dim:
            The dimension of ``x`` whose size decides how many leading positions are taken
            from the precomputed table.

        return_only_pos:
            Whether to return only the positional encoding.

        Returns
        -------
        If return_only_pos is True:
            pos_enc:
                The positional encoding, of shape ``[1, x.size(dim), d_hid]`` (capped at
                ``n_positions`` rows by the slicing).
        else:
            x_with_pos:
                Output tensor, the input tensor with the positional encoding added.

        """
        # clone + detach so callers can never modify or backpropagate into the buffer
        pos_enc = self.pos_table[:, : x.size(dim)].clone().detach()

        if return_only_pos:
            return pos_enc

        x_with_pos = x + pos_enc
        return x_with_pos
class TokenEmbedding(nn.Module):
    """Embed the feature values of each step with a circular 1D convolution.

    Parameters
    ----------
    c_in:
        The number of input channels, i.e. the size of the last dimension of the input.

    d_model:
        The number of output channels, i.e. the embedding dimension.

    """

    def __init__(self, c_in: int, d_model: int):
        super().__init__()
        # NOTE(review): this is a version check done through the comparison operators of
        # torch.__version__; it is kept as-is. Presumably older PyTorch needed padding=2
        # for circular padding with kernel_size=3 -- confirm if old versions matter.
        padding = 1 if torch.__version__ >= "1.5.0" else 2
        self.tokenConv = nn.Conv1d(
            in_channels=c_in,
            out_channels=d_model,
            kernel_size=3,
            padding=padding,
            padding_mode="circular",
            bias=False,
        )
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(m.weight, mode="fan_in", nonlinearity="leaky_relu")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed ``x``.

        The last two dimensions are swapped before the convolution (which works
        channel-first) and swapped back afterwards, so a 3D input whose last dimension is
        ``c_in`` yields an output whose last dimension is ``d_model``.
        """
        x = self.tokenConv(x.permute(0, 2, 1)).transpose(1, 2)
        return x


class FixedEmbedding(nn.Module):
    """A frozen embedding table filled with sinusoidal values.

    Row ``i`` of the table holds sines (even columns) and cosines (odd columns) of ``i``
    scaled by geometrically decreasing frequencies. The table is never trained.

    Parameters
    ----------
    c_in:
        The number of rows in the table, i.e. the number of distinct indices.

    d_model:
        The embedding dimension. Both even and odd values are supported.

    """

    def __init__(self, c_in: int, d_model: int):
        super().__init__()

        w = torch.zeros(c_in, d_model).float()
        w.requires_grad = False

        position = torch.arange(0, c_in).float().unsqueeze(1)
        div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp()

        w[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer cosine column than sine columns, so the
        # frequencies are trimmed to match; for an even d_model the slice is a no-op.
        w[:, 1::2] = torch.cos(position * div_term[: d_model // 2])

        self.emb = nn.Embedding(c_in, d_model)
        self.emb.weight = nn.Parameter(w, requires_grad=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Look up the rows for the integer indices in ``x``; the result is detached."""
        return self.emb(x).detach()


class TemporalEmbedding(nn.Module):
    """Embed integer calendar features and sum the per-feature embeddings.

    Parameters
    ----------
    d_model:
        The embedding dimension.

    embed_type:
        ``"fixed"`` uses :class:`FixedEmbedding` tables; any other value uses trainable
        ``nn.Embedding`` tables.

    freq:
        When ``"t"``, an additional minute embedding is created and used.

    """

    def __init__(self, d_model: int, embed_type: str = "fixed", freq: str = "h"):
        super().__init__()

        # the number of rows of each lookup table
        minute_size = 4
        hour_size = 24
        weekday_size = 7
        day_size = 32
        month_size = 13

        Embed = FixedEmbedding if embed_type == "fixed" else nn.Embedding
        if freq == "t":
            self.minute_embed = Embed(minute_size, d_model)
        self.hour_embed = Embed(hour_size, d_model)
        self.weekday_embed = Embed(weekday_size, d_model)
        self.day_embed = Embed(day_size, d_model)
        self.month_embed = Embed(month_size, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed the calendar features in ``x``.

        ``x`` is indexed as a 3D tensor whose last dimension holds, in order: month, day,
        weekday, hour and (only read when the minute embedding exists) minute. The values
        are cast to integers and used as table indices.
        """
        x = x.long()
        # the minute column is only read when the module was built with freq="t"
        minute_x = self.minute_embed(x[:, :, 4]) if hasattr(self, "minute_embed") else 0.0
        hour_x = self.hour_embed(x[:, :, 3])
        weekday_x = self.weekday_embed(x[:, :, 2])
        day_x = self.day_embed(x[:, :, 1])
        month_x = self.month_embed(x[:, :, 0])
        return hour_x + weekday_x + day_x + month_x + minute_x


class TimeFeatureEmbedding(nn.Module):
    """Project real-valued time features to ``d_model`` with a bias-free linear layer.

    Parameters
    ----------
    d_model:
        The embedding dimension.

    freq:
        A frequency code that selects the number of input features. It must be one of
        ``"h"``, ``"t"``, ``"s"``, ``"m"``, ``"a"``, ``"w"``, ``"d"``, ``"b"``;
        any other value raises ``KeyError``.

    """

    def __init__(self, d_model: int, freq: str = "h"):
        super().__init__()

        # frequency code -> number of time features expected in the last input dimension
        freq_map = {"h": 4, "t": 5, "s": 6, "m": 1, "a": 1, "w": 2, "d": 3, "b": 3}
        d_inp = freq_map[freq]
        self.embed = nn.Linear(d_inp, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the linear projection to the last dimension of ``x``."""
        return self.embed(x)


class DataEmbedding(nn.Module):
    """Combine value, temporal and positional embeddings, followed by dropout.

    Parameters
    ----------
    c_in:
        The number of input features of the value embedding.

    d_model:
        The embedding dimension.

    embed_type:
        ``"timeF"`` uses :class:`TimeFeatureEmbedding` for the timestamps; any other value
        is passed on to :class:`TemporalEmbedding`.

    freq:
        The frequency code passed to the temporal embedding.

    dropout:
        The dropout probability applied to the summed embedding.

    with_pos:
        Whether to add the positional encoding.

    n_max_steps:
        The max number of positions of the positional encoding; only used when
        ``with_pos`` is True.

    """

    def __init__(
        self,
        c_in: int,
        d_model: int,
        embed_type: str = "fixed",
        freq: str = "h",
        dropout: float = 0.1,
        with_pos: bool = True,
        n_max_steps: int = 1000,
    ):
        super().__init__()

        self.with_pos = with_pos

        self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
        if with_pos:
            self.position_embedding = PositionalEncoding(d_hid=d_model, n_positions=n_max_steps)
        self.temporal_embedding = (
            TemporalEmbedding(d_model=d_model, embed_type=embed_type, freq=freq)
            if embed_type != "timeF"
            else TimeFeatureEmbedding(d_model=d_model, freq=freq)
        )
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x: torch.Tensor, x_timestamp: torch.Tensor = None) -> torch.Tensor:
        """Embed ``x`` and, when given, its timestamps.

        Parameters
        ----------
        x:
            The values to embed; the last dimension must have size ``c_in``.

        x_timestamp:
            Optional time features fed to the temporal embedding. When None, the temporal
            embedding is skipped.

        Returns
        -------
        The sum of the enabled embeddings after dropout.

        """
        x = self.value_embedding(x)
        if x_timestamp is not None:
            x = x + self.temporal_embedding(x_timestamp)
        if self.with_pos:
            x += self.position_embedding(x, return_only_pos=True)
        return self.dropout(x)