Source code for pypots.optim.lr_scheduler.multiplicative_lrs
"""Multiplicative learning rate scheduler."""# Created by Wenjie Du <wenjay.du@gmail.com># License: BSD-3-Clausefrom.baseimportLRScheduler,logger
class MultiplicativeLR(LRScheduler):
    """Multiply the learning rate of each parameter group by the factor given in the specified function.
    When last_epoch=-1, sets initial lr as lr.

    Parameters
    ----------
    lr_lambda: Callable or list,
        A function which computes a multiplicative factor given an integer parameter epoch, or a list of such
        functions, one for each group in ``optimizer.param_groups``.

    last_epoch: int,
        The index of the last epoch. Default: -1.

    verbose: bool,
        If ``True``, prints a message to stdout for each update. Default: ``False``.

    Notes
    -----
    This class works the same as ``torch.optim.lr_scheduler.MultiplicativeLR``.
    The only difference, which is also the reason we implement these schedulers, is that you don't have to
    pass the corresponding optimizers into them immediately while initializing them.

    Example
    -------
    >>> lmbda = lambda epoch: 0.95
    >>> # xdoctest: +SKIP
    >>> scheduler = MultiplicativeLR(lr_lambda=lmbda)
    >>> adam = pypots.optim.Adam(lr=1e-3, lr_scheduler=scheduler)

    """

    def __init__(self, lr_lambda, last_epoch=-1, verbose=False):
        super().__init__(last_epoch, verbose)
        self.lr_lambda = lr_lambda
        self.lr_lambdas = None
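    # A hedged extension of the docstring example (a sketch, not part of the module):
    # the docstring notes that `lr_lambda` may also be a list with one callable per
    # parameter group, e.g. to decay two groups at different rates.
    #
    # >>> lmbdas = [lambda epoch: 0.95, lambda epoch: 0.90]
    # >>> scheduler = MultiplicativeLR(lr_lambda=lmbdas)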
    def init_scheduler(self, optimizer):
        if not isinstance(self.lr_lambda, (list, tuple)):
            # a single callable is shared across all parameter groups
            self.lr_lambdas = [self.lr_lambda] * len(optimizer.param_groups)
        else:
            if len(self.lr_lambda) != len(optimizer.param_groups):
                raise ValueError(
                    "Expected {} lr_lambdas, but got {}".format(
                        len(optimizer.param_groups), len(self.lr_lambda)
                    )
                )
            self.lr_lambdas = list(self.lr_lambda)
        super().init_scheduler(optimizer)
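    # A minimal sketch of how a list of lambdas is bound to parameter groups, assuming
    # `init_scheduler` accepts a ``torch.optim.Optimizer`` as the base class suggests.
    # The two-group SGD optimizer below is a hypothetical example, not part of pypots.
    #
    # >>> import torch
    # >>> w1 = torch.nn.Parameter(torch.zeros(1))
    # >>> w2 = torch.nn.Parameter(torch.zeros(1))
    # >>> optimizer = torch.optim.SGD(
    # ...     [{"params": [w1], "lr": 1e-2}, {"params": [w2], "lr": 1e-3}], lr=1e-2
    # ... )
    # >>> scheduler = MultiplicativeLR(lr_lambda=[lambda e: 0.95, lambda e: 0.90])
    # >>> scheduler.init_scheduler(optimizer)  # one lambda per group, else ValueError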
    def get_lr(self):
        if not self._get_lr_called_within_step:
            logger.warning(
                "⚠️ To get the last learning rate computed by the scheduler, please use `get_last_lr()`."
            )

        if self.last_epoch > 0:
            # multiply each group's current lr by its factor for this epoch
            return [
                group["lr"] * lmbda(self.last_epoch)
                for lmbda, group in zip(self.lr_lambdas, self.optimizer.param_groups)
            ]
        else:
            return [group["lr"] for group in self.optimizer.param_groups]
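    # A minimal sketch (not part of the module) of the update rule `get_lr` implements:
    # each epoch multiplies the current lr by lr_lambda(epoch), so a constant factor of
    # 0.95 gives lr_e = lr_0 * 0.95 ** e after e epochs.
    #
    # >>> lr = 1e-3
    # >>> for epoch in range(1, 4):
    # ...     lr = lr * 0.95  # same rule as get_lr: group["lr"] * lmbda(last_epoch)
    # >>> round(lr, 10)  # 1e-3 * 0.95 ** 3
    # 0.000857375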