Toggle Light / Dark / Auto color theme
Toggle table of contents sidebar
Source code for pypots.nn.modules.grud.backbone
"""Backbone (recurrent core) of the GRU-D model."""
# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause
from typing import Tuple
import torch
import torch.nn as nn
from .layers import TemporalDecay
[docs]
class BackboneGRUD(nn.Module):
    """Recurrent backbone of GRU-D.

    At every time step the previous hidden state is scaled by a learned factor
    computed from the time gaps, missing input entries are replaced by a blend of
    the LOCF-filled value and the empirical mean, and a ``GRUCell`` is stepped on
    the concatenation of the (imputed) values, the scaled hidden state and the mask.

    Parameters
    ----------
    n_steps :
        Number of time steps ``forward`` iterates over (it loops over
        ``range(n_steps)`` along dimension 1 of the inputs).
    n_features :
        Number of features per time step.
    rnn_hidden_size :
        Size of the GRU hidden state.
    """

    def __init__(
        self,
        n_steps: int,
        n_features: int,
        rnn_hidden_size: int,
    ):
        super().__init__()
        self.n_steps = n_steps
        self.n_features = n_features
        self.rnn_hidden_size = rnn_hidden_size

        # create models
        # GRUCell input is cat([x_replaced (n_features), hidden_state (rnn_hidden_size),
        # mask (n_features)]), hence n_features * 2 + rnn_hidden_size.
        self.rnn_cell = nn.GRUCell(
            self.n_features * 2 + self.rnn_hidden_size,
            self.rnn_hidden_size,
        )
        # gamma_h: per-hidden-unit factor computed from the per-feature time gaps;
        # it multiplies the hidden state at each step.
        self.temp_decay_h = TemporalDecay(
            input_size=self.n_features,
            output_size=self.rnn_hidden_size,
            diag=False,
        )
        # gamma_x: per-feature factor used to blend the LOCF value with the mean.
        # diag=True presumably restricts each feature to its own gap — see TemporalDecay.
        self.temp_decay_x = TemporalDecay(
            input_size=self.n_features,
            output_size=self.n_features,
            diag=True,
        )

    def forward(self, X, missing_mask, deltas, empirical_mean, X_filledLOCF) -> Tuple[torch.Tensor, ...]:
        """Forward processing of GRU-D.

        Parameters
        ----------
        X :
            Input values, indexed as ``[batch, time, features]``. Where
            ``missing_mask`` is 0 the value is replaced by an imputed one.
        missing_mask :
            Same layout as ``X``; 1 keeps the value from ``X``, 0 uses the
            imputed value. Used arithmetically, so presumably a 0/1 float mask.
        deltas :
            Same layout as ``X``; time gaps fed to the temporal-decay layers.
        empirical_mean :
            Fallback value for missing entries; must broadcast against a
            ``[batch, features]`` slice (presumably shape ``[features]`` — TODO confirm).
        X_filledLOCF :
            Same layout as ``X``; per its name, ``X`` with missing entries filled by
            last-observation-carried-forward.

        NOTE: the inputs are combined arithmetically (``m * x + (1 - m) * x_h``), so a
        NaN in ``X`` or ``X_filledLOCF`` propagates even at masked-out positions
        (0 * NaN is NaN); callers presumably fill NaNs beforehand — confirm.

        Returns
        -------
        representation_collector :
            Tensor of shape ``[batch, n_steps, rnn_hidden_size]``; entry ``t`` is the
            scaled hidden state recorded *before* the GRU update at step ``t``
            (entry 0 therefore derives from the all-zero initial state).
        hidden_state :
            Hidden state of shape ``[batch, rnn_hidden_size]`` after the last step.
        """
        batch_size = X.size(0)
        # Match X's dtype as well as its device: with the default dtype the state
        # would be float32 and break GRUCell when the model runs in half/bfloat16.
        hidden_state = torch.zeros(
            (batch_size, self.rnn_hidden_size),
            dtype=X.dtype,
            device=X.device,
        )

        representation_collector = []
        for t in range(self.n_steps):
            # for data, [batch, time, features]
            x = X[:, t, :]  # values
            m = missing_mask[:, t, :]  # mask
            d = deltas[:, t, :]  # delta, time gap
            x_filledLOCF = X_filledLOCF[:, t, :]

            gamma_h = self.temp_decay_h(d)
            gamma_x = self.temp_decay_x(d)

            # Scale the previous hidden state, and record it before this step's update.
            hidden_state = hidden_state * gamma_h
            representation_collector.append(hidden_state)

            # Imputation: blend LOCF value and empirical mean by gamma_x; entries with
            # m == 1 keep their observed value from X.
            x_h = gamma_x * x_filledLOCF + (1 - gamma_x) * empirical_mean
            x_replaced = m * x + (1 - m) * x_h

            data_input = torch.cat([x_replaced, hidden_state, m], dim=1)
            hidden_state = self.rnn_cell(data_input, hidden_state)

        representation_collector = torch.stack(representation_collector, dim=1)
        return representation_collector, hidden_state