""" """# Created by Wenjie Du <wenjay.du@gmail.com># License: BSD-3-Clauseimportmathimporttorchimporttorch.nnasnnimporttorch.nn.functionalasFfromtorch.autogradimportVariablefromtorch.nn.parameterimportParameter
class FeatureRegression(nn.Module):
    """The module used to capture the correlation between features for imputation in BRITS.

    Each output feature is a linear combination of the *other* input features:
    the diagonal of the weight matrix is masked out, so a feature never
    contributes to its own estimate.

    Attributes
    ----------
    W : tensor
        The weights (parameters) of the module, of shape ``(input_size, input_size)``.

    b : tensor
        The bias of the module, of shape ``(input_size,)``.

    m (buffer) : tensor
        The mask matrix, a square matrix whose diagonal entries are all zeros
        and whose other entries are all ones. It is multiplied element-wise
        with the weight matrix to mask out the estimation contributions from
        the features themselves.

    Parameters
    ----------
    input_size :
        The feature dimension of the input.

    """

    def __init__(self, input_size: int):
        super().__init__()
        # torch.empty replaces the legacy torch.Tensor(size...) constructor;
        # the values are overwritten by _reset_parameters() below.
        self.W = Parameter(torch.empty(input_size, input_size))
        self.b = Parameter(torch.empty(input_size))

        # Ones everywhere except the diagonal. Registered as a buffer so it
        # follows the module across devices and is saved in the state dict
        # without being treated as a trainable parameter.
        m = torch.ones(input_size, input_size) - torch.eye(input_size, input_size)
        self.register_buffer("m", m)

        self._reset_parameters()

    def _reset_parameters(self) -> None:
        """Initialise ``W`` and ``b`` uniformly in ``[-1/sqrt(n), 1/sqrt(n)]``, ``n = W.size(0)``."""
        std_dev = 1.0 / math.sqrt(self.W.size(0))
        # nn.init.uniform_ fills in place under no_grad, avoiding direct
        # manipulation of ``.data``.
        nn.init.uniform_(self.W, -std_dev, std_dev)
        nn.init.uniform_(self.b, -std_dev, std_dev)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward processing of the NN module.

        Parameters
        ----------
        x :
            The input for processing. Its last dimension is expected to equal
            ``input_size`` (required by the linear transform below).

        Returns
        -------
        output :
            The processed result containing imputation from feature
            regression, computed as ``x @ (W * m).T + b``.

        """
        # Buffers are plain tensors; the deprecated autograd.Variable wrapper
        # is not needed for the mask to take part in the computation.
        output = F.linear(x, self.W * self.m, self.b)
        return output