# Source code for pypots.nn.modules.brits.layers

""" """

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.nn.parameter import Parameter


class FeatureRegression(nn.Module):
    """The module used to capture the correlation between features for imputation in BRITS.

    Each output feature is a linear combination of all *other* input features plus a
    bias: the weight matrix is multiplied element-wise by a mask with a zero diagonal
    before the linear map, so a feature never contributes to its own estimate.

    Attributes
    ----------
    W : tensor
        The weights (parameters) of the module, of shape ``(input_size, input_size)``.

    b : tensor
        The bias of the module, of shape ``(input_size,)``.

    m (buffer) : tensor
        The mask matrix, a square matrix whose diagonal entries are all zeros and whose
        other entries are all ones. It is applied to the weight matrix to mask out the
        estimation contributions from features themselves, which helps enhance the
        imputation performance of the network.

    Parameters
    ----------
    input_size :
        the feature dimension of the input

    """

    def __init__(self, input_size: int):
        super().__init__()
        # ``torch.empty`` replaces the legacy ``torch.Tensor(*sizes)`` constructor; the
        # actual values are filled in by ``_reset_parameters`` below.
        self.W = Parameter(torch.empty(input_size, input_size))
        self.b = Parameter(torch.empty(input_size))
        # Zero-diagonal mask. Registered as a buffer (not a Parameter) so it moves with the
        # module and is stored in its state dict, but is not a trainable parameter.
        m = torch.ones(input_size, input_size) - torch.eye(input_size, input_size)
        self.register_buffer("m", m)
        self._reset_parameters()

    def _reset_parameters(self) -> None:
        """Initialise ``W`` and ``b`` uniformly in ``[-1/sqrt(input_size), 1/sqrt(input_size)]``."""
        std_dev = 1.0 / math.sqrt(self.W.size(0))
        # ``nn.init.uniform_`` fills in place under ``no_grad`` instead of poking ``.data``.
        nn.init.uniform_(self.W, -std_dev, std_dev)
        if self.b is not None:
            nn.init.uniform_(self.b, -std_dev, std_dev)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward processing of the NN module.

        Parameters
        ----------
        x : tensor,
            the input for processing; its last dimension must equal ``input_size``

        Returns
        -------
        output: tensor,
            the processed result containing imputation from feature regression

        """
        # ``self.m`` is already a plain tensor buffer; the deprecated ``Variable`` wrapper
        # used previously was a no-op and is not needed. Masking zeroes the diagonal so
        # each feature is estimated only from the others.
        output = F.linear(x, self.W * self.m, self.b)
        return output