Source code for pypots.nn.modules.koopa.layers

""" """

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause

import math

import torch
import torch.fft
import torch.nn as nn


class FourierFilter(nn.Module):
    """Fourier filter that splits a series into time-variant and time-invariant terms.

    The frequency bins listed in ``mask_spectrum`` are zeroed in the real FFT of the
    input (taken along dim 1). The inverse transform of the remaining spectrum is the
    time-variant term ``x_var``; the time-invariant term is ``x_inv = x - x_var``, i.e.
    the component carried by the masked bins.

    Parameters
    ----------
    mask_spectrum :
        Indices of the rFFT bins (along dim 1) to mask out. Anything accepted by
        tensor advanced indexing works (a list of ints or a LongTensor).
    """

    def __init__(self, mask_spectrum):
        super().__init__()
        self.mask_spectrum = mask_spectrum

    def forward(self, x):
        """Split ``x`` into ``(x_var, x_inv)``.

        Parameters
        ----------
        x : torch.Tensor
            Real tensor whose dim 1 is the time axis; assumes (batch, time, channels),
            as implied by the ``[:, bins, :]`` indexing below — TODO confirm at callers.

        Returns
        -------
        x_var, x_inv : torch.Tensor
            Both have the same shape as ``x`` and sum to ``x``.
        """
        seq_len = x.shape[1]
        xf = torch.fft.rfft(x, dim=1)
        mask = torch.ones_like(xf)
        mask[:, self.mask_spectrum, :] = 0
        # ``n`` must be given explicitly: by default irfft returns 2 * (n_bins - 1)
        # samples, which is one shorter than ``x`` when seq_len is odd and would make
        # the subtraction below fail.
        x_var = torch.fft.irfft(xf * mask, n=seq_len, dim=1)
        x_inv = x - x_var
        return x_var, x_inv
class MLP(nn.Module):
    """Multilayer perceptron that encodes/decodes high-dimensional representations of sequential data.

    Architecture: ``Linear(d_in, d_hidden) -> act -> Dropout``, then
    ``max(n_hidden_layers - 2, 0)`` repeats of ``Linear(d_hidden, d_hidden) -> act -> Dropout``,
    then ``Linear(d_hidden, d_out)``. The network therefore contains
    ``max(n_hidden_layers, 2)`` Linear layers in total.

    Parameters
    ----------
    d_in : int
        Size of the last input dimension.
    d_out : int
        Size of the last output dimension.
    d_hidden : int
        Width of the hidden Linear layers.
    n_hidden_layers : int
        Controls the number of Linear layers as described above.
    dropout : float
        Dropout probability applied after every hidden activation.
    activation : str
        ``"relu"`` or ``"tanh"``.

    Raises
    ------
    NotImplementedError
        If ``activation`` is not one of the supported names.
    """

    def __init__(
        self,
        d_in,
        d_out,
        d_hidden=128,
        n_hidden_layers=2,
        dropout=0.05,
        activation="tanh",
    ):
        super().__init__()
        self.d_in = d_in
        self.d_out = d_out
        self.d_hidden = d_hidden
        self.n_hidden_layers = n_hidden_layers
        self.dropout = dropout

        if activation == "relu":
            self.activation = nn.ReLU()
        elif activation == "tanh":
            self.activation = nn.Tanh()
        else:
            raise NotImplementedError(f"Unsupported activation {activation!r}; expected 'relu' or 'tanh'.")

        # The same (parameter-free) activation module is reused in every position.
        layers = [nn.Linear(d_in, d_hidden), self.activation, nn.Dropout(dropout)]
        for _ in range(n_hidden_layers - 2):
            layers += [nn.Linear(d_hidden, d_hidden), self.activation, nn.Dropout(dropout)]
        layers.append(nn.Linear(d_hidden, d_out))
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to the last dimension of ``x``.

        Presumably x is B x S x d_in and the result B x S x d_out (Linear layers only
        act on the last dimension, so any leading dims are carried through).
        """
        return self.layers(x)
def _replace_nan_with_identity(mat, batch_size, message):
    """Return ``mat`` unchanged, or a batch of identity matrices if it contains NaN.

    The identity is created with ``mat``'s dtype and device so it can be used in
    ``torch.bmm`` wherever ``mat`` would have been (a default-dtype eye would break
    float64 inputs).
    """
    if not torch.isnan(mat).any():
        return mat
    print(message)
    eye = torch.eye(mat.shape[1], dtype=mat.dtype, device=mat.device)
    return eye.unsqueeze(0).repeat(batch_size, 1, 1)


class KPLayer(nn.Module):
    """Find the one-step transition of a linear system by DMD and roll it out iteratively.

    ``self.K`` holds the operator fitted on the most recent call (shape B x E x E).
    """

    def __init__(self):
        super().__init__()
        self.K = None  # B E E, set by one_step_forward

    def one_step_forward(self, z, return_rec=False, return_K=False):
        """Fit K by least squares on consecutive snapshots and predict one step ahead.

        Parameters
        ----------
        z : torch.Tensor
            Snapshots of shape (B, input_len, E) with input_len > 1.
        return_rec : bool
            If True, also return the one-step reconstruction of ``z``.
        return_K : bool
            Accepted for backward compatibility; currently has no effect.

        Returns
        -------
        ``z_pred`` of shape (B, 1, E), or ``(z_rec, z_pred)`` when ``return_rec``
        is True, with ``z_rec`` of the same shape as ``z``.
        """
        B, input_len, _ = z.shape
        assert input_len > 1, "snapshots number should be larger than 1"
        x, y = z[:, :-1], z[:, 1:]

        # Solve x @ K = y in the least-squares sense; fall back to identity on NaN.
        K = torch.linalg.lstsq(x, y).solution  # B E E
        self.K = _replace_nan_with_identity(K, B, "Encounter K with nan, replace K by identity matrix")

        z_pred = torch.bmm(z[:, -1:], self.K)
        if return_rec:
            # The first snapshot has no predecessor, so it is passed through as-is.
            z_rec = torch.cat((z[:, :1], torch.bmm(x, self.K)), dim=1)
            return z_rec, z_pred
        return z_pred

    def forward(self, z, pred_len=1):
        """Reconstruct ``z`` and predict ``pred_len`` future snapshots.

        Returns ``(z_rec, z_preds)`` with shapes (B, input_len, E) and (B, pred_len, E).
        """
        assert pred_len >= 1, "prediction length should not be less than 1"
        z_rec, z_pred = self.one_step_forward(z, return_rec=True)
        z_preds = [z_pred]
        for _ in range(1, pred_len):
            z_pred = torch.bmm(z_pred, self.K)
            z_preds.append(z_pred)
        return z_rec, torch.cat(z_preds, dim=1)


class KPLayerApprox(nn.Module):
    """Find the Koopman transition of a linear system by DMD with a multistep K approximation.

    Instead of applying K once per step, predictions use the matrix power
    ``K^min(pred_len, input_len)`` stored in ``self.K_step``.
    """

    def __init__(self):
        super().__init__()
        self.K = None  # B E E
        self.K_step = None  # B E E

    def forward(self, z, pred_len=1):
        """Reconstruct ``z`` and forecast ``pred_len`` snapshots.

        Parameters
        ----------
        z : torch.Tensor
            Koopman-invariant-space representation of shape (B, L, E), L > 1.
        pred_len : int
            Number of future snapshots to forecast.

        Returns
        -------
        z_rec : torch.Tensor
            Reconstructed representation, shape (B, L, E).
        z_pred : torch.Tensor
            Forecast representation, shape (B, pred_len, E).
        """
        B, input_len, _ = z.shape
        assert input_len > 1, "snapshots number should be larger than 1"
        x, y = z[:, :-1], z[:, 1:]

        # Solve x @ K = y in the least-squares sense; fall back to identity on NaN.
        K = torch.linalg.lstsq(x, y).solution  # B E E
        self.K = _replace_nan_with_identity(K, B, "Encounter K with nan, replace K by identity matrix")

        z_rec = torch.cat((z[:, :1], torch.bmm(x, self.K)), dim=1)  # B L E

        # z_t @ K^p approximates z_{t+p}. With pred_len <= input_len, one power maps the
        # last pred_len snapshots onto the next pred_len steps; otherwise the whole
        # window is advanced by input_len steps repeatedly.
        step_power = min(pred_len, input_len)
        K_step = torch.linalg.matrix_power(self.K, step_power)
        self.K_step = _replace_nan_with_identity(
            K_step, B, "Encounter multistep K with nan, replace it by identity matrix"
        )

        if pred_len <= input_len:
            z_pred = torch.bmm(z[:, -pred_len:, :], self.K_step)
        else:
            window, chunks = z, []
            for _ in range(math.ceil(pred_len / input_len)):
                window = torch.bmm(window, self.K_step)
                chunks.append(window)
            z_pred = torch.cat(chunks, dim=1)[:, :pred_len, :]
        return z_rec, z_pred
class TimeVarKP(nn.Module):
    """Koopman predictor with DMD (analytical solution of the Koopman operator).

    Utilizes local variations within individual sliding windows to predict the future
    of the time-variant term. The input is left-padded with its own last
    ``padding_len`` steps so that its length is a whole number (``freq``) of
    ``seg_len`` segments. Each segment is encoded into one snapshot, and the dynamics
    layer reconstructs those snapshots and rolls them ``step`` segments forward.
    """

    def __init__(
        self,
        enc_in=8,
        input_len=96,
        pred_len=96,
        seg_len=24,
        dynamic_dim=128,
        encoder=None,
        decoder=None,
        multistep=False,
    ):
        super().__init__()
        self.input_len = input_len
        self.pred_len = pred_len
        self.enc_in = enc_in
        self.seg_len = seg_len
        self.dynamic_dim = dynamic_dim
        self.multistep = multistep
        self.encoder, self.decoder = encoder, decoder
        self.freq = math.ceil(self.input_len / self.seg_len)  # segment number of input
        self.step = math.ceil(self.pred_len / self.seg_len)  # segment number of output
        self.padding_len = self.seg_len * self.freq - self.input_len
        # Approximate multistep K by KPLayerApprox when pred_len is large
        self.dynamics = KPLayerApprox() if self.multistep else KPLayer()

    def forward(self, x):
        """Return ``(x_rec, x_pred)`` for an input of shape (B, L, C).

        ``x_rec`` (B, input_len, enc_in) is aligned with ``x``. ``x_pred``
        (B, pred_len, enc_in) continues it. Assumes L == input_len and C == enc_in,
        as the reshapes below rely on both.
        """
        B, L, C = x.shape
        # Left-pad with the tail of x so the length becomes freq * seg_len.
        res = torch.cat((x[:, L - self.padding_len :, :], x), dim=1)
        res = res.chunk(self.freq, dim=1)  # F x B P C, P means seg_len
        res = torch.stack(res, dim=1).reshape(B, self.freq, -1)  # B F PC

        res = self.encoder(res)  # B F H
        x_rec, x_pred = self.dynamics(res, self.step)  # B F H, B S H

        x_rec = self.decoder(x_rec)  # B F PC
        x_rec = x_rec.reshape(B, self.freq, self.seg_len, self.enc_in)
        # The reconstruction lives in the padded time frame: skip the padded prefix so
        # that x_rec[:, t] corresponds to x[:, t] (identical to before when padding_len == 0).
        start = self.padding_len
        x_rec = x_rec.reshape(B, -1, self.enc_in)[:, start : start + self.input_len, :]  # B L C

        x_pred = self.decoder(x_pred)  # B S PC
        x_pred = x_pred.reshape(B, self.step, self.seg_len, self.enc_in)
        x_pred = x_pred.reshape(B, -1, self.enc_in)[:, : self.pred_len, :]  # B S C

        return x_rec, x_pred
class TimeInvKP(nn.Module):
    """Koopman predictor with a learnable Koopman operator.

    Utilizes lookback and forecast window snapshots to predict the future of the
    time-invariant term. The operator ``K`` is a bias-free Linear layer initialized
    to an orthogonal matrix (``U @ Vh`` from the SVD of a random Gaussian matrix).
    """

    def __init__(self, input_len=96, pred_len=96, dynamic_dim=128, encoder=None, decoder=None):
        super().__init__()
        self.dynamic_dim = dynamic_dim
        self.input_len = input_len
        self.pred_len = pred_len
        self.encoder = encoder
        self.decoder = decoder

        # Stable initialization: the orthogonal factor of a random matrix has all
        # singular values equal to 1. torch.linalg.svd replaces the deprecated
        # torch.svd (which returned V rather than Vh); U @ Vh == U @ V.t().
        K_init = torch.randn(self.dynamic_dim, self.dynamic_dim)
        U, _, Vh = torch.linalg.svd(K_init)
        self.K = nn.Linear(self.dynamic_dim, self.dynamic_dim, bias=False)
        with torch.no_grad():
            self.K.weight.copy_(U @ Vh)

    def forward(self, x):
        """Map an input of shape (B, L, C) to a forecast of shape (B, S, C).

        The encoder, K and decoder act on the time axis (after transposing to
        B x C x L), so the encoder must accept L features and the decoder emit S.
        """
        res = x.transpose(1, 2)  # B C L
        res = self.encoder(res)  # B C H
        res = self.K(res)  # B C H
        res = self.decoder(res)  # B C S
        return res.transpose(1, 2)  # B S C