Source code for pypots.nn.modules.imputeformer.mlp
"""The implementation of the MLPs for ImputeFormer :cite:`nie2024imputeformer`"""

# Created by Tong Nie <nietong@tongji.edu.cn> and Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause

import torch.nn as nn


class Dense(nn.Module):
    """Fully-connected building block: linear projection, ReLU, optional dropout.

    Parameters
    ----------
    input_size :
        Number of input features of the linear projection.
    output_size :
        Number of output features of the linear projection.
    dropout :
        Dropout probability. A dropout module is only created when this is
        strictly greater than zero.
    bias :
        Whether the linear projection has a bias term.
    """

    def __init__(self, input_size, output_size, dropout=0.0, bias=True):
        super().__init__()
        projection = nn.Linear(input_size, output_size, bias=bias)
        activation = nn.ReLU()
        # An identity stand-in keeps the container at three entries whether or
        # not dropout is enabled.
        if dropout > 0.0:
            regulariser = nn.Dropout(dropout)
        else:
            regulariser = nn.Identity()
        self.layer = nn.Sequential(projection, activation, regulariser)

    def forward(self, x):
        """Run ``x`` through the projection, activation and dropout stack."""
        return self.layer(x)


class MLP(nn.Module):
    """Stack of ``Dense`` blocks, optionally followed by a linear readout.

    Parameters
    ----------
    input_size :
        Input feature size of the first ``Dense`` block.
    hidden_size :
        Output feature size of every ``Dense`` block, and the input size of
        all blocks after the first one.
    output_size :
        When not ``None``, a final ``nn.Linear(hidden_size, output_size)`` is
        applied to the output of the stack.
    n_layers :
        Number of ``Dense`` blocks in the stack.
    dropout :
        Dropout value forwarded to every ``Dense`` block.
    """

    def __init__(self, input_size, hidden_size, output_size=None, n_layers=1, dropout=0.0):
        super().__init__()

        blocks = []
        in_features = input_size
        for _ in range(n_layers):
            blocks.append(
                Dense(
                    input_size=in_features,
                    output_size=hidden_size,
                    dropout=dropout,
                )
            )
            # Every block after the first consumes the hidden representation.
            in_features = hidden_size
        self.mlp = nn.Sequential(*blocks)

        if output_size is None:
            # Register the name with a ``None`` value so ``forward`` can test
            # ``self.readout`` without the attribute being missing.
            self.register_parameter("readout", None)
        else:
            self.readout = nn.Linear(hidden_size, output_size)

    def forward(self, x, u=None):
        """Apply the dense stack to ``x``, then the readout when one exists.

        ``u`` is accepted but is not used by this method.
        """
        hidden = self.mlp(x)
        if self.readout is None:
            return hidden
        return self.readout(hidden)