""" """# Created by Wenjie Du <wenjay.du@gmail.com># License: BSD-3-Clauseimporttorchimporttorch.fftimporttorch.nnasnn
class RevIN(nn.Module):
    """RevIN: Reversible Instance Normalization.

    Normalizes an input with per-sample statistics and stores those statistics
    so the transformation can later be reversed. The statistics are reduced over
    every axis except the first and the last one. When ``affine`` is enabled, a
    learnable per-feature scale and shift are applied after normalization.
    Their shape is ``(n_features,)``, so they broadcast against the last axis.

    Parameters
    ----------
    n_features :
        the number of features or channels (size of the input's last axis
        when ``affine`` is True)

    eps :
        a value added for numerical stability

    affine :
        if True, RevIN has learnable affine parameters

    """

    def __init__(
        self,
        n_features: int,
        eps: float = 1e-9,
        affine: bool = True,
    ):
        super().__init__()
        self.n_features = n_features
        self.eps = eps
        self.affine = affine

        if self.affine:
            self._init_params()

    def _init_params(self):
        """Create the learnable affine parameters, both of shape ``(n_features,)``."""
        # identity initialization: weight=1, bias=0
        self.affine_weight = nn.Parameter(torch.ones(self.n_features))
        self.affine_bias = nn.Parameter(torch.zeros(self.n_features))

    def _normalize(self, x, missing_mask=None):
        """Normalize ``x`` and cache the statistics used by ``_denormalize``.

        Parameters
        ----------
        x :
            Input tensor. Mean and standard deviation are reduced over every
            axis except the first and the last one.

        missing_mask :
            Optional mask broadcastable to ``x``. Entries equal to 1 are treated
            as observed; all other entries are excluded from the statistics.

        Returns
        -------
        The normalized tensor, with the affine transform applied if enabled.
        Side effect: sets ``self.mean`` and ``self.stdev`` (detached).
        """
        dim2reduce = tuple(range(1, x.ndim - 1))

        # calculate mean and variance
        if missing_mask is None:
            # original implementation
            mean = torch.mean(x, dim=dim2reduce, keepdim=True)
            variance = torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False)
        else:
            # pypots implementation for POTS data: statistics over observed
            # entries only
            observed = missing_mask == 1
            # eps keeps the division finite when a slice has no observed values
            n_observed = torch.sum(observed, dim=dim2reduce, keepdim=True) + self.eps
            # zero out missing entries (this also discards NaN placeholders)
            # so they cannot leak into the sums
            x_observed = x.masked_fill(~observed, 0)
            mean = torch.sum(x_observed, dim=dim2reduce, keepdim=True) / n_observed
            # center BEFORE squaring; otherwise this is the raw second moment
            # E[x^2] rather than the variance
            centered = (x - mean).masked_fill(~observed, 0)
            variance = (
                torch.sum(centered * centered, dim=dim2reduce, keepdim=True)
                / n_observed
            )
        # same eps placement as the unmasked branch, so both branches agree
        # when every entry is observed
        stdev = torch.sqrt(variance + self.eps)

        # detach mean and stdev so the statistics are treated as constants
        # during backpropagation
        self.mean = mean.detach()
        self.stdev = stdev.detach()

        # normalize the input
        x = x - self.mean
        x = x / self.stdev

        if self.affine:
            # apply affine transformation
            x = x * self.affine_weight
            x = x + self.affine_bias

        return x

    def _denormalize(self, x):
        """Reverse ``_normalize`` using the statistics it cached.

        ``_normalize`` must have been called first. It sets ``self.mean`` and
        ``self.stdev``; without them, attribute access fails.
        """
        # reverse affine transformation
        if self.affine:
            x = x - self.affine_bias
            # eps guards against division by a zero learned weight
            x = x / (self.affine_weight + self.eps)

        # denormalize the input
        x = x * self.stdev
        x = x + self.mean
        return x