"""The optimizer wrapper for PyTorch Adagrad."""# Created by Wenjie Du <wenjay.du@gmail.com># License: BSD-3-ClausefromtypingimportIterable,Optionalfromtorch.optimimportAdagradastorch_Adagradfrom.baseimportOptimizerfrom.lr_scheduler.baseimportLRScheduler
class Adagrad(Optimizer):
    """The optimizer wrapper for PyTorch Adagrad :class:`torch.optim.Adagrad`.

    Parameters
    ----------
    lr : float
        The learning rate of the optimizer.

    lr_decay : float
        Learning rate decay.

    weight_decay : float
        Weight decay (L2 penalty).

    eps : float
        Term added to the denominator to improve numerical stability.

    initial_accumulator_value : float
        A floating point value. Starting value for the accumulators, must be positive.

    lr_scheduler : pypots.optim.lr_scheduler.base.LRScheduler
        The learning rate scheduler of the optimizer.

    """

    def __init__(
        self,
        lr: float = 0.01,
        lr_decay: float = 0,
        weight_decay: float = 0.01,
        initial_accumulator_value: float = 0.01,  # it is set as 0 in the torch implementation, but delta shouldn't be 0
        eps: float = 1e-08,
        lr_scheduler: Optional[LRScheduler] = None,
    ):
        super().__init__(lr, lr_scheduler)
        self.lr_decay = lr_decay
        self.weight_decay = weight_decay
        self.initial_accumulator_value = initial_accumulator_value
        self.eps = eps
    def init_optimizer(self, params: Iterable) -> None:
        """Initialize the torch optimizer wrapped by this class.

        Parameters
        ----------
        params : Iterable
            An iterable of ``torch.Tensor`` or ``dict``. Specifies what Tensors should be optimized.

        """
        self.torch_optimizer = torch_Adagrad(
            params=params,
            lr=self.lr,
            lr_decay=self.lr_decay,
            weight_decay=self.weight_decay,
            initial_accumulator_value=self.initial_accumulator_value,
            eps=self.eps,
        )
        if self.lr_scheduler is not None:
            self.lr_scheduler.init_scheduler(self.torch_optimizer)
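

# A minimal usage sketch, not part of the module's public surface. It assumes a
# small ``torch.nn.Linear`` model and relies only on what is defined above: the
# torch optimizer is built lazily by ``init_optimizer`` and is then reachable
# through the ``torch_optimizer`` attribute (run via ``python -m`` so the
# relative imports resolve).
if __name__ == "__main__":
    import torch

    model = torch.nn.Linear(8, 1)

    optimizer = Adagrad(lr=0.01, weight_decay=0.01)
    optimizer.init_optimizer(model.parameters())  # creates the wrapped torch.optim.Adagrad

    # one illustrative optimization step on random data
    loss = model(torch.randn(4, 8)).sum()
    optimizer.torch_optimizer.zero_grad()
    loss.backward()
    optimizer.torch_optimizer.step()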