"""The optimizer wrapper for PyTorch SGD :class:`torch.optim.SGD`."""# Created by Wenjie Du <wenjay.du@gmail.com># License: BSD-3-ClausefromtypingimportIterable,Optionalfromtorch.optimimportSGDastorch_SGDfrom.baseimportOptimizerfrom.lr_scheduler.baseimportLRScheduler
class SGD(Optimizer):
    """The optimizer wrapper for PyTorch SGD :class:`torch.optim.SGD`.

    Parameters
    ----------
    lr : float
        The learning rate of the optimizer.

    momentum : float
        Momentum factor.

    weight_decay : float
        Weight decay (L2 penalty).

    dampening : float
        Dampening for momentum.

    nesterov : bool
        Whether to enable Nesterov momentum.

    lr_scheduler : pypots.optim.lr_scheduler.base.LRScheduler
        The learning rate scheduler of the optimizer.

    """

    def __init__(
        self,
        lr: float = 0.001,
        momentum: float = 0,
        weight_decay: float = 0,
        dampening: float = 0,
        nesterov: bool = False,
        lr_scheduler: Optional[LRScheduler] = None,
    ):
        super().__init__(lr, lr_scheduler)
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.dampening = dampening
        self.nesterov = nesterov
    def init_optimizer(self, params: Iterable) -> None:
        """Initialize the torch optimizer wrapped by this class.

        Parameters
        ----------
        params :
            An iterable of ``torch.Tensor`` or ``dict``. Specifies what Tensors should be optimized.

        """
        self.torch_optimizer = torch_SGD(
            params=params,
            lr=self.lr,
            momentum=self.momentum,
            weight_decay=self.weight_decay,
            dampening=self.dampening,
            nesterov=self.nesterov,
        )
        if self.lr_scheduler is not None:
            self.lr_scheduler.init_scheduler(self.torch_optimizer)
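

if __name__ == "__main__":
    # A minimal usage sketch, not part of the library API: it assumes a toy
    # ``torch.nn.Linear`` model and illustrates the two-step pattern this
    # wrapper implements -- configure hyperparameters first, then bind the
    # model's parameters via ``init_optimizer`` before stepping. The model,
    # data shapes, and hyperparameter values are arbitrary examples.
    import torch

    model = torch.nn.Linear(8, 1)
    optimizer = SGD(lr=0.01, momentum=0.9, nesterov=True)
    # The inner ``torch.optim.SGD`` is only built here, once the parameters
    # to optimize are known; any attached lr_scheduler is initialized too.
    optimizer.init_optimizer(model.parameters())

    loss = model(torch.randn(4, 8)).sum()
    optimizer.torch_optimizer.zero_grad()
    loss.backward()
    optimizer.torch_optimizer.step()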