Source code for pypots.nn.modules.grud.layers

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.nn.parameter import Parameter


class TemporalDecay(nn.Module):
    """The module used to generate the temporal decay factor gamma in the GRU-D model.
    Please refer to the original paper :cite:`che2018GRUD` for more details.

    The factor is computed as ``gamma = exp(-relu(delta @ W.T + b))``.

    Attributes
    ----------
    W: tensor, shape [output_size, input_size]
        The weights (parameters) of the module.

    b: tensor, shape [output_size]
        The bias of the module.

    m: tensor, shape [input_size, input_size]
        Identity mask registered as a buffer; only present when ``diag`` is True.

    Parameters
    ----------
    input_size : int,
        the feature dimension of the input

    output_size : int,
        the feature dimension of the output

    diag : bool,
        whether to multiply the weight element-wise with an identity matrix
        before forward processing, so that only the diagonal of ``W`` is used.
        Requires ``input_size == output_size``.

    Raises
    ------
    ValueError
        If ``diag`` is True and ``input_size != output_size``.

    References
    ----------
    .. [1] `Che, Zhengping, Sanjay Purushotham, Kyunghyun Cho, David Sontag, and Yan Liu.
        "Recurrent neural networks for multivariate time series with missing values."
        Scientific reports 8, no. 1 (2018): 6085.
        <https://www.nature.com/articles/s41598-018-24271-9.pdf>`_

    """

    def __init__(self, input_size: int, output_size: int, diag: bool = False):
        super().__init__()
        self.diag = diag

        # Explicit check instead of `assert`: asserts are stripped under `python -O`,
        # and a mismatched mask could then broadcast silently (e.g. input_size == 1).
        if self.diag and input_size != output_size:
            raise ValueError(
                "diag=True requires input_size == output_size, "
                f"but got input_size={input_size} and output_size={output_size}"
            )

        # Allocated uninitialised here; the values are filled in by _reset_parameters().
        self.W = Parameter(torch.empty(output_size, input_size))
        self.b = Parameter(torch.empty(output_size))

        if self.diag:
            # A buffer (not a parameter): it follows the module across devices and is
            # saved in the state dict, but is never updated by the optimizer.
            self.register_buffer("m", torch.eye(input_size, input_size))

        self._reset_parameters()

    def _reset_parameters(self) -> None:
        """Initialise ``W`` and ``b`` uniformly in ``[-k, k)`` with ``k = 1 / sqrt(output_size)``."""
        bound = 1.0 / math.sqrt(self.W.size(0))
        nn.init.uniform_(self.W, -bound, bound)
        nn.init.uniform_(self.b, -bound, bound)

    def forward(self, delta: torch.Tensor) -> torch.Tensor:
        """Forward processing of this NN module.

        Parameters
        ----------
        delta : tensor, shape [n_samples, n_steps, n_features]
            The time gaps. The last dimension must equal ``input_size``.

        Returns
        -------
        gamma : tensor, shape [n_samples, n_steps, output_size], values in (0,1]
            The temporal decay factor. It has the same shape as ``delta`` whenever
            ``output_size == input_size``, which is always the case with ``diag=True``.

        """
        # With diag=True the identity mask zeroes the off-diagonal weights, so each
        # output feature only depends on the matching input feature.
        weight = self.W * self.m if self.diag else self.W
        # relu keeps the exponent non-negative, hence exp(-x) falls in (0, 1].
        gamma = F.relu(F.linear(delta, weight, self.b))
        return torch.exp(-gamma)