""" """# Created by Wenjie Du <wenjay.du@gmail.com># License: BSD-3-Clauseimporttorchfromtorch.nn.modules.lossimport_Lossfrom..functionalimport(calc_mae,calc_mse,calc_rmse,calc_mre,)
class Criterion(_Loss):
    def __init__(
        self,
        lower_better: bool = True,
    ):
        """Shared base class for the loss functions and metrics implemented in PyPOTS.

        Subclasses provide the actual computation by overriding :meth:`forward`;
        this base class only records which direction of the criterion value
        counts as an improvement.

        Parameters
        ----------
        lower_better :
            Whether a smaller criterion value means better model performance.
            Defaults to True, which matches most loss functions
            (e.g. MSE, cross entropy). Pass False for criteria where a larger
            value is better (e.g. accuracy).
        """
        super(Criterion, self).__init__()
        # Stored as a public attribute; nothing inside this class reads it.
        self.lower_better = lower_better

    def forward(
        self,
        logits: torch.Tensor,
        targets: torch.Tensor,
    ) -> torch.Tensor:
        """Compute the criterion value; subclasses must override this method.

        Parameters
        ----------
        logits :
            The model outputs, i.e. the predicted unnormalized logits.
        targets :
            The ground-truth values.

        Returns
        -------
        torch.Tensor
            The criterion value, as produced by a subclass implementation.

        Raises
        ------
        NotImplementedError
            Always, because the base class has no implementation.
        """
        raise NotImplementedError()