""" """# Created by Wenjie Du <wdu@time-series.ai># License: BSD-3-Clauseimporttorchimporttorch.nnasnn
class BackboneFITS(nn.Module):
    """Frequency-domain backbone.

    The model low-passes the input spectrum, upsamples it with a complex
    linear layer, and transforms it back to a longer time window.

    ``forward`` works in four steps:

    1. Takes the rFFT of the input along dim 1.
    2. Keeps only the lowest ``cut_freq`` bins.
    3. Maps those bins to ``int(cut_freq * length_ratio)`` bins with a complex
       ``nn.Linear``.
    4. Zero-pads the result to the rFFT size of a window of
       ``n_steps + n_pred_steps`` steps, inverts it, and rescales the output
       by ``length_ratio``.

    Parameters
    ----------
    n_steps :
        Number of input time steps. It is the denominator of ``length_ratio``.
    n_features :
        Number of feature channels. When ``individual`` is True, one
        upsampling layer is created per channel.
    n_pred_steps :
        Number of extra time steps in the reconstructed output window.
    cut_freq :
        Number of lowest rFFT bins kept (the low-pass cut-off). Stored as
        ``dominance_freq``.
    individual :
        If True, each feature channel gets its own complex linear layer.
        Otherwise a single layer is shared by all channels.
    """

    def __init__(
        self,
        n_steps: int,
        n_features: int,
        n_pred_steps: int,
        cut_freq: int,
        individual: bool,
    ):
        super().__init__()

        self.n_steps = n_steps
        self.n_features = n_features
        self.n_pred_steps = n_pred_steps
        self.individual = individual

        self.dominance_freq = cut_freq
        # Output window length relative to the input window length.
        self.length_ratio = (n_steps + n_pred_steps) / n_steps

        n_up_freq = int(self.dominance_freq * self.length_ratio)
        if self.individual:
            # One complex frequency-upsampling layer per feature channel,
            # created in channel order.
            self.freq_upsampler = nn.ModuleList(
                nn.Linear(self.dominance_freq, n_up_freq).to(torch.cfloat)
                for _ in range(self.n_features)
            )
        else:
            # A single complex layer for frequency upsampling, shared by all
            # feature channels.
            self.freq_upsampler = nn.Linear(self.dominance_freq, n_up_freq).to(torch.cfloat)

    def forward(self, x):
        """Reconstruct ``x`` over a window of ``n_steps + n_pred_steps`` steps.

        Parameters
        ----------
        x :
            Real tensor indexed as (batch, time, feature). The FFT runs along
            dim 1. The time length is presumably ``n_steps``; this is not
            checked here.

        Returns
        -------
        torch.Tensor
            Real tensor of shape (batch, n_steps + n_pred_steps, feature).

        Raises
        ------
        AssertionError
            If the rFFT of ``x`` has fewer bins than ``cut_freq``.
        """
        total_steps = self.n_steps + self.n_pred_steps

        low_specx = torch.fft.rfft(x, dim=1)
        assert low_specx.size(1) >= self.dominance_freq, (
            f"The sequence length after FFT {low_specx.size(1)} is less than the cut frequency {self.dominance_freq}. "
            f"Please check the input sequence length, or decrease the cut frequency."
        )
        # Low-pass filter: keep only the lowest `dominance_freq` bins.
        low_specx = low_specx[:, : self.dominance_freq, :]

        if self.individual:
            # Each per-feature layer maps (batch, cut) -> (batch, up). The
            # results are stacked back along the feature axis.
            low_specxy_ = torch.stack(
                [layer(low_specx[:, :, i]) for i, layer in enumerate(self.freq_upsampler)],
                dim=2,
            )
        else:
            # nn.Linear acts on the last dim, so move frequency there and back.
            low_specxy_ = self.freq_upsampler(low_specx.permute(0, 2, 1)).permute(0, 2, 1)

        # Zero-pad the spectrum up to the rFFT size of the output window.
        # Upsampled bins beyond that size cannot be represented in the output
        # window, so they are dropped. This can happen when cut_freq is close
        # to the input's Nyquist bin.
        n_out_freq = total_steps // 2 + 1
        n_keep = min(low_specxy_.size(1), n_out_freq)
        low_specxy = low_specxy_.new_zeros((low_specxy_.size(0), n_out_freq, low_specxy_.size(2)))
        low_specxy[:, :n_keep, :] = low_specxy_[:, :n_keep, :]

        # Pass n explicitly. irfft's default n = 2 * (bins - 1) is always
        # even, so it would lose one step whenever total_steps is odd.
        low_xy = torch.fft.irfft(low_specxy, n=total_steps, dim=1)
        low_xy = low_xy * self.length_ratio  # energy compensation for the length change
        return low_xy