"""The implementation of Mean value imputation."""# Created by Wenjie Du <wenjay.du@gmail.com># License: BSD-3-ClauseimportwarningsfromtypingimportUnion,Optionalimporth5pyimportnumpyasnpimporttorchfrom..baseimportBaseImputer
class Mean(BaseImputer):
    """Mean value imputation method.

    Every missing (NaN) entry is replaced with the mean of its feature, where the
    mean is computed over all samples and time steps of the data passed to
    ``predict()``. The method has no trainable parameters, so ``fit()`` does
    nothing but emit a warning.
    """

    def __init__(
        self,
    ):
        super().__init__()

    def fit(
        self,
        train_set: Union[dict, str],
        val_set: Optional[Union[dict, str]] = None,
        file_type: str = "hdf5",
    ) -> None:
        """Train the imputer on the given data.

        Warnings
        --------
        Mean imputation class does not need to run fit().
        Please run func ``predict()`` directly.
        """
        # stacklevel=2 attributes the warning to the caller's fit() line.
        warnings.warn(
            "Mean imputation class has no parameter to train. "
            "Please run func `predict()` directly.",
            stacklevel=2,
        )

    def predict(
        self,
        test_set: Union[dict, str],
        file_type: str = "hdf5",
        **kwargs,
    ) -> dict:
        """Fill NaN entries of ``X`` with the per-feature mean of ``X``.

        Parameters
        ----------
        test_set :
            Either a dict holding the data under key ``"X"``, or a path to an
            HDF5 file containing a dataset named ``"X"``. ``X`` may be a numpy
            array, a torch tensor, or a list (converted with ``np.asarray``),
            and must have 3 dimensions ``[n_samples, n_steps, n_features]``.
            NaN marks a missing value.
        file_type :
            Not used by this method; kept for interface compatibility.
        **kwargs :
            Ignored.

        Returns
        -------
        dict
            ``{"imputation": imputed_data}`` where ``imputed_data`` is a copy of
            ``X`` (numpy array, or torch tensor on the same device) with every
            NaN replaced by the nan-mean of its feature. Only NaN is replaced;
            infinities are left untouched. A feature whose values are all NaN
            stays NaN (its nan-mean is NaN).

        Raises
        ------
        ValueError
            If ``X`` is not a numpy array / torch tensor / list, or does not
            have exactly 3 dimensions.
        """
        if isinstance(test_set, str):
            with h5py.File(test_set, "r") as f:
                X = f["X"][:]
        else:
            X = test_set["X"]

        # Convert lists before validating: a list has no ``.shape`` attribute.
        if isinstance(X, list):
            X = np.asarray(X)

        if not isinstance(X, (np.ndarray, torch.Tensor)):
            raise ValueError(
                f"X should be a numpy array, a torch tensor, or a list, "
                f"but got {type(X)}"
            )
        # Explicit check instead of ``assert`` so it still runs under ``python -O``.
        if len(X.shape) != 3:
            raise ValueError(
                f"Input X should have 3 dimensions [n_samples, n_steps, n_features], "
                f"but the actual shape of X: {X.shape}"
            )
        n_features = X.shape[-1]

        if isinstance(X, np.ndarray):
            imputed_data = np.copy(X)
            # Flatten samples and steps so each column is one feature.
            mean_values = np.nanmean(imputed_data.reshape(-1, n_features), axis=0)
            # Mask only NaN (np.nan_to_num would also clobber +/-inf).
            nan_mask = np.isnan(imputed_data)
            imputed_data[nan_mask] = np.broadcast_to(mean_values, imputed_data.shape)[
                nan_mask
            ]
        else:
            imputed_data = torch.clone(X)
            # Stay on X's device: no .numpy() round-trip, which fails for CUDA
            # tensors and tensors that require grad.
            mean_values = torch.nanmean(imputed_data.reshape(-1, n_features), dim=0)
            nan_mask = torch.isnan(imputed_data)
            imputed_data[nan_mask] = mean_values.expand_as(imputed_data)[nan_mask]

        result_dict = {
            "imputation": imputed_data,
        }
        return result_dict