# Source code for pypots.nn.modules.grud.backbone

""" """

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause

from typing import Tuple

import torch
import torch.nn as nn

from .layers import TemporalDecay


class BackboneGRUD(nn.Module):
    """Recurrent GRU-D backbone with learned temporal decay on the inputs and the hidden state.

    At every time step the previous hidden state is multiplied by a decay factor
    computed from the time gaps. Each input value flagged as missing is replaced by
    a decay-weighted blend of its last-observation-carried-forward (LOCF) value and
    the empirical mean. The result is then fed, together with the decayed hidden
    state and the mask, into a ``GRUCell``.

    Parameters
    ----------
    n_steps :
        Number of time steps ``forward`` iterates over (the first ``n_steps``
        positions along dim 1 of the inputs).
    n_features :
        Number of features per time step.
    rnn_hidden_size :
        Size of the GRU hidden state.
    """

    def __init__(
        self,
        n_steps: int,
        n_features: int,
        rnn_hidden_size: int,
    ):
        super().__init__()
        self.n_steps = n_steps
        self.n_features = n_features
        self.rnn_hidden_size = rnn_hidden_size

        # The GRU input at each step is [imputed x, decayed hidden state, mask],
        # hence n_features * 2 + rnn_hidden_size input units.
        self.rnn_cell = nn.GRUCell(self.n_features * 2 + self.rnn_hidden_size, self.rnn_hidden_size)
        # Produces the per-step multiplicative decay for the hidden state from the time gaps.
        self.temp_decay_h = TemporalDecay(input_size=self.n_features, output_size=self.rnn_hidden_size, diag=False)
        # Produces the per-step, per-feature decay used to impute missing inputs.
        self.temp_decay_x = TemporalDecay(input_size=self.n_features, output_size=self.n_features, diag=True)

    def forward(
        self,
        X: torch.Tensor,
        missing_mask: torch.Tensor,
        deltas: torch.Tensor,
        empirical_mean: torch.Tensor,
        X_filledLOCF: torch.Tensor,
    ) -> Tuple[torch.Tensor, ...]:
        """Run the GRU-D recurrence over the first ``self.n_steps`` time steps.

        Parameters
        ----------
        X :
            Observed values, indexed as ``[batch, time, features]``.
        missing_mask :
            Same layout as ``X``. Where it is 1 the value from ``X`` is used, and
            where it is 0 the imputed value is used (``m * x + (1 - m) * x_h``).
            It is also concatenated into the GRU input.
        deltas :
            Time gaps, same layout as ``X``. They are fed to both temporal-decay modules.
        empirical_mean :
            Fallback value for missing entries. It must broadcast against a
            ``[batch, n_features]`` tensor (presumably shape ``[n_features]``;
            confirm with callers).
        X_filledLOCF :
            ``X`` with missing values filled by the last observation carried
            forward, same layout as ``X``.

        Returns
        -------
        representation_collector :
            Tensor of shape ``[batch, n_steps, rnn_hidden_size]``. Entry ``t`` is the
            decayed hidden state *before* the GRU update at step ``t``. The recurrence
            starts from a zero state, so entry 0 is zero for finite decay factors.
        hidden_state :
            Hidden state after the last GRU update, shape ``[batch, rnn_hidden_size]``.
        """
        batch_size = X.size(0)
        # Match X's device and dtype so the recurrence runs in the input's precision.
        hidden_state = torch.zeros((batch_size, self.rnn_hidden_size), dtype=X.dtype, device=X.device)
        representation_collector = []

        for t in range(self.n_steps):
            # data layout: [batch, time, features]
            x = X[:, t, :]  # values
            m = missing_mask[:, t, :]  # mask
            d = deltas[:, t, :]  # time gaps
            x_filledLOCF = X_filledLOCF[:, t, :]

            gamma_h = self.temp_decay_h(d)
            gamma_x = self.temp_decay_x(d)

            # Decay the hidden state by elapsed time. The decayed (pre-update) state
            # is what gets recorded as the representation for step t.
            hidden_state = hidden_state * gamma_h
            representation_collector.append(hidden_state)

            # Imputation: blend the LOCF value toward the empirical mean as gamma_x
            # shrinks, and keep the observed value wherever the mask is 1.
            x_h = gamma_x * x_filledLOCF + (1 - gamma_x) * empirical_mean
            x_replaced = m * x + (1 - m) * x_h

            data_input = torch.cat([x_replaced, hidden_state, m], dim=1)
            hidden_state = self.rnn_cell(data_input, hidden_state)

        representation_collector = torch.stack(representation_collector, dim=1)
        return representation_collector, hidden_state