""" """# Created by Jun Wang <jwangfx@connect.ust.hk> and Wenjie Du <wenjay.du@gmail.com># License: BSD-3-ClausefromtypingimportTupleimporttorchimporttorch.nnasnnimporttorch.nn.functionalasFfrom.layersimportUsganDiscriminatorfrom..britsimportBackboneBRITSfrom....nn.functionalimportcalc_mse
def forward(
    self,
    inputs: dict,
    training_object: str = "generator",
) -> Tuple[torch.Tensor, ...]:
    """Run the generator and, in training mode, compute the requested GAN loss.

    Parameters
    ----------
    inputs :
        Batch dictionary passed unchanged to ``self.generator``. In training
        mode ``inputs["forward"]["missing_mask"]`` is also read (and
        ``inputs["forward"]["X"]`` for the generator objective).
    training_object :
        Which objective to compute in training mode: ``"generator"``
        (default) or ``"discriminator"``. Ignored when ``self.training`` is
        False.

    Returns
    -------
    tuple
        Eval mode: ``(imputed_data, reconstruction)``.
        Training mode: ``(imputed_data, reconstruction, loss)`` where ``loss``
        is the discrimination loss or the generation loss depending on
        ``training_object``.

    Raises
    ------
    ValueError
        In training mode, if ``training_object`` is neither ``"generator"``
        nor ``"discriminator"``.
    """
    # Reject unknown objectives up front instead of silently treating them as
    # the generator objective (and before paying for a generator pass).
    if self.training and training_object not in ("generator", "discriminator"):
        raise ValueError(
            'training_object must be "generator" or "discriminator", '
            f"but got {training_object!r}"
        )

    # Only the first three generator outputs are used here.
    (
        imputed_data,
        f_reconstruction,
        b_reconstruction,
        _,
        _,
        _,
        _,
    ) = self.generator(inputs)
    # Mean of the forward- and backward-direction reconstructions.
    reconstruction = (f_reconstruction + b_reconstruction) / 2

    if not self.training:
        return imputed_data, reconstruction

    # NOTE(review): assumes missing_mask is 1 for observed and 0 for missing
    # entries -- TODO confirm against the data pipeline.
    forward_missing_mask = inputs["forward"]["missing_mask"]

    if training_object == "discriminator":
        # detach() keeps the discriminator loss from back-propagating into
        # the generator's parameters.
        discrimination = self.discriminator(
            imputed_data.detach(), forward_missing_mask
        )
        discrimination_loss = F.binary_cross_entropy_with_logits(
            discrimination, forward_missing_mask
        )
        return imputed_data, reconstruction, discrimination_loss

    # Generator objective.
    forward_X = inputs["forward"]["X"]
    discrimination = self.discriminator(imputed_data, forward_missing_mask)
    # weight=1-mask keeps only entries where missing_mask == 0; negating the
    # BCE rewards the generator when the discriminator is wrong on them.
    adversarial_loss = -F.binary_cross_entropy_with_logits(
        discrimination,
        forward_missing_mask,
        weight=1 - forward_missing_mask,
    )
    # Masked reconstruction error against X, plus a hard-coded 0.1-weighted
    # consistency term between the two directional reconstructions.
    reconstruction_loss = calc_mse(
        forward_X, reconstruction, forward_missing_mask
    ) + 0.1 * calc_mse(f_reconstruction, b_reconstruction)
    generation_loss = adversarial_loss + self.lambda_mse * reconstruction_loss
    return imputed_data, reconstruction, generation_loss