Toggle Light / Dark / Auto color theme
Toggle table of contents sidebar
Source code for pypots.nn.modules.mrnn.layers
"""Layers used by the M-RNN model."""
# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
[docs]
class MrnnFcnRegression(nn.Module):
    """M-RNN fully connected regression layer.

    The layer combines three inputs through learnable square matrices and
    then passes the hidden activation through a final linear layer::

        h_t     = sigmoid(x @ (U * m).T + target @ (V1 * m).T
                          + missing_mask @ V2.T + beta)
        x_hat_t = sigmoid(final_linear(h_t))

    ``m`` is a fixed (non-trainable) matrix of ones with a zero diagonal.
    Because ``U`` and ``V1`` are multiplied element-wise by ``m``, hidden
    unit ``j`` receives no direct contribution from feature ``j`` of ``x``
    or ``target``. ``V2`` is applied to ``missing_mask`` without masking.

    Parameters
    ----------
    feature_num : int
        Size of the feature dimension. All weight matrices are
        ``(feature_num, feature_num)`` and ``beta`` is ``(feature_num,)``.
        Must be at least 1.

    Raises
    ------
    ValueError
        If ``feature_num`` is smaller than 1.
    """

    def __init__(self, feature_num: int):
        if feature_num < 1:
            # Fail early with a clear message; otherwise the 1/sqrt(size)
            # bound in reset_parameters() would divide by zero.
            raise ValueError(
                f"feature_num must be a positive integer, got {feature_num!r}"
            )
        super().__init__()

        # Allocate uninitialised storage; values are drawn in reset_parameters().
        self.U = Parameter(torch.empty(feature_num, feature_num))
        self.V1 = Parameter(torch.empty(feature_num, feature_num))
        self.V2 = Parameter(torch.empty(feature_num, feature_num))
        self.beta = Parameter(torch.empty(feature_num))  # bias beta
        self.final_linear = nn.Linear(feature_num, feature_num)

        # Ones everywhere except the diagonal. Registered as a buffer so it
        # follows the module across devices and is saved in the state dict,
        # but is never updated by the optimizer.
        m = torch.ones(feature_num, feature_num) - torch.eye(feature_num, feature_num)
        self.register_buffer("m", m)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-draw ``U``, ``V1``, ``V2`` and ``beta`` uniformly from ``[-stdv, stdv)``.

        ``stdv`` is ``1 / sqrt(feature_num)``. ``final_linear`` is not
        touched here and keeps the initialisation it already has.
        """
        stdv = 1.0 / math.sqrt(self.U.size(0))
        # nn.init.uniform_ fills in place under no_grad, so autograd does not
        # record the initialisation. The order of draws is U, V1, V2, beta.
        for param in (self.U, self.V1, self.V2, self.beta):
            nn.init.uniform_(param, -stdv, stdv)

    def forward(
        self,
        x: torch.Tensor,
        missing_mask: torch.Tensor,
        target: torch.Tensor,
    ) -> torch.Tensor:
        """Compute the regression estimate ``x_hat_t``.

        Parameters
        ----------
        x : torch.Tensor
            Input whose last dimension is ``feature_num``.
        missing_mask : torch.Tensor
            Mask tensor whose last dimension is ``feature_num``.
            NOTE(review): presumably a 0/1 observation indicator in a float
            dtype -- confirm against the caller.
        target : torch.Tensor
            Second value input whose last dimension is ``feature_num``.
            NOTE(review): its exact meaning is defined by the caller -- confirm.

        Returns
        -------
        torch.Tensor
            ``x_hat_t``, the output of a final sigmoid, so every element
            lies between 0 and 1.
        """
        # Zero the diagonals of U and V1 so hidden unit j gets no direct
        # contribution from feature j of x / target.
        from_x = F.linear(x, self.U * self.m)
        from_target = F.linear(target, self.V1 * self.m)
        from_mask = F.linear(missing_mask, self.V2)

        h_t = torch.sigmoid(from_x + from_target + from_mask + self.beta)
        return torch.sigmoid(self.final_linear(h_t))