Source code for pypots.nn.modules.stemgnn.backbone

"""The backbone network of the StemGNN model."""

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause

import torch
import torch.nn as nn
import torch.nn.functional as F

from .layers import StockBlockLayer


[docs] class BackboneStemGNN(nn.Module): def __init__( self, units, stack_cnt, time_step, multi_layer, horizon=1, dropout_rate=0.5, leaky_rate=0.2, ): super().__init__() self.unit = units self.stack_cnt = stack_cnt self.unit = units self.alpha = leaky_rate self.time_step = time_step self.horizon = horizon self.weight_key = nn.Parameter(torch.zeros(size=(self.unit, 1))) nn.init.xavier_uniform_(self.weight_key.data, gain=1.414) self.weight_query = nn.Parameter(torch.zeros(size=(self.unit, 1))) nn.init.xavier_uniform_(self.weight_query.data, gain=1.414) self.GRU = nn.GRU(self.time_step, self.unit) self.multi_layer = multi_layer self.stock_block = nn.ModuleList() self.stock_block.extend( [StockBlockLayer(self.time_step, self.unit, self.multi_layer, stack_cnt=i) for i in range(self.stack_cnt)] ) self.fc = nn.Sequential( nn.Linear(int(self.time_step), int(self.time_step)), nn.LeakyReLU(), nn.Linear(int(self.time_step), self.horizon), ) self.leakyrelu = nn.LeakyReLU(self.alpha) self.dropout = nn.Dropout(p=dropout_rate)
[docs] @staticmethod def get_laplacian(graph, normalize): """ return the laplacian of the graph. :param graph: the graph structure without self loop, [N, N]. :param normalize: whether to used the normalized laplacian. :return: graph laplacian. """ if normalize: D = torch.diag(torch.sum(graph, dim=-1) ** (-1 / 2)) L = torch.eye(graph.size(0), device=graph.device, dtype=graph.dtype) - torch.mm(torch.mm(D, graph), D) else: D = torch.diag(torch.sum(graph, dim=-1)) L = D - graph return L
[docs] @staticmethod def cheb_polynomial(laplacian): """ Compute the Chebyshev Polynomial, according to the graph laplacian. :param laplacian: the graph laplacian, [N, N]. :return: the multi order Chebyshev laplacian, [K, N, N]. """ N = laplacian.size(0) # [N, N] laplacian = laplacian.unsqueeze(0) first_laplacian = torch.zeros([1, N, N], device=laplacian.device, dtype=torch.float) second_laplacian = laplacian third_laplacian = (2 * torch.matmul(laplacian, second_laplacian)) - first_laplacian forth_laplacian = 2 * torch.matmul(laplacian, third_laplacian) - second_laplacian multi_order_laplacian = torch.cat([first_laplacian, second_laplacian, third_laplacian, forth_laplacian], dim=0) return multi_order_laplacian
def latent_correlation_layer(self, x): input, _ = self.GRU(x.permute(2, 0, 1).contiguous()) input = input.permute(1, 0, 2).contiguous() attention = self.self_graph_attention(input) attention = torch.mean(attention, dim=0) degree = torch.sum(attention, dim=1) # laplacian is sym or not attention = 0.5 * (attention + attention.T) degree_l = torch.diag(degree) diagonal_degree_hat = torch.diag(1 / (torch.sqrt(degree) + 1e-7)) laplacian = torch.matmul(diagonal_degree_hat, torch.matmul(degree_l - attention, diagonal_degree_hat)) mul_L = self.cheb_polynomial(laplacian) return mul_L, attention def self_graph_attention(self, input): input = input.permute(0, 2, 1).contiguous() bat, N, fea = input.size() key = torch.matmul(input, self.weight_key) query = torch.matmul(input, self.weight_query) data = key.repeat(1, 1, N).view(bat, N * N, 1) + query.repeat(1, N, 1) data = data.squeeze(2) data = data.view(bat, N, -1) data = self.leakyrelu(data) attention = F.softmax(data, dim=2) attention = self.dropout(attention) return attention @staticmethod def graph_fft(input, eigenvectors): return torch.matmul(eigenvectors, input)
[docs] def forward(self, x): mul_L, attention = self.latent_correlation_layer(x) X = x.unsqueeze(1).permute(0, 1, 3, 2).contiguous() result = None for stack_i in range(self.stack_cnt): forecast, X = self.stock_block[stack_i](X, mul_L) if stack_i == 0: result = forecast else: result += forecast # residual connection forecast_result = self.fc(result) if forecast_result.size()[-1] == 1: return forecast_result.unsqueeze(1).squeeze(-1), attention else: return forecast_result.permute(0, 2, 1).contiguous(), attention