"""The optimizer wrapper for PyTorch Adam."""# Created by Wenjie Du <wenjay.du@gmail.com># License: BSD-3-ClausefromtypingimportIterable,Tuple,Optionalfromtorch.optimimportAdamastorch_Adamfrom.baseimportOptimizerfrom.lr_scheduler.baseimportLRScheduler
class Adam(Optimizer):
    """The optimizer wrapper for PyTorch Adam :class:`torch.optim.Adam`.

    Parameters
    ----------
    lr : float
        The learning rate of the optimizer.

    betas : Tuple[float, float]
        Coefficients used for computing running averages of gradient and its square.

    eps : float
        Term added to the denominator to improve numerical stability.

    weight_decay : float
        Weight decay (L2 penalty).

    amsgrad : bool
        Whether to use the AMSGrad variant of this algorithm from the paper :cite:`reddi2018OnTheConvergence`.

    lr_scheduler : pypots.optim.lr_scheduler.base.LRScheduler
        The learning rate scheduler of the optimizer.

    """

    def __init__(
        self,
        lr: float = 0.001,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-08,
        weight_decay: float = 0,
        amsgrad: bool = False,
        lr_scheduler: Optional[LRScheduler] = None,
    ):
        super().__init__(lr, lr_scheduler)
        self.betas = betas
        self.eps = eps
        self.weight_decay = weight_decay
        self.amsgrad = amsgrad
    def init_optimizer(self, params: Iterable) -> None:
        """Initialize the torch optimizer wrapped by this class.

        Parameters
        ----------
        params :
            An iterable of ``torch.Tensor`` or ``dict``. Specifies what Tensors should be optimized.

        """
        self.torch_optimizer = torch_Adam(
            params=params,
            lr=self.lr,
            betas=self.betas,
            eps=self.eps,
            weight_decay=self.weight_decay,
            amsgrad=self.amsgrad,
        )

        if self.lr_scheduler is not None:
            self.lr_scheduler.init_scheduler(self.torch_optimizer)
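
# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module). It shows the
# two-phase pattern this wrapper supports: the optimizer is constructed with
# hyperparameters only, and ``init_optimizer`` later binds it to the actual
# parameters, creating the underlying ``torch.optim.Adam``. The ``nn.Linear``
# model, the random batch, and the single training step are assumptions made
# for the demo, not part of the PyPOTS API. Because of the relative imports
# above, run this as a module within the package rather than as a standalone
# script.
if __name__ == "__main__":
    import torch

    model = torch.nn.Linear(8, 1)

    # Build the wrapper first; no parameters are attached yet.
    optimizer = Adam(lr=1e-3, weight_decay=1e-5)

    # Bind the model's parameters; this instantiates torch.optim.Adam inside.
    optimizer.init_optimizer(model.parameters())

    # One gradient step driven through the wrapped torch optimizer.
    loss = model(torch.randn(4, 8)).mean()
    optimizer.torch_optimizer.zero_grad()
    loss.backward()
    optimizer.torch_optimizer.step()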