Source code for pypots.optim.lr_scheduler.constant_lrs
"""Constant learning rate scheduler."""# Created by Wenjie Du <wenjay.du@gmail.com># License: BSD-3-Clausefrom.baseimportLRScheduler,logger
class ConstantLR(LRScheduler):
    """Decays the learning rate of each parameter group by a small constant factor until the
    number of epochs reaches a pre-defined milestone: total_iters.
    Notice that such decay can happen simultaneously with other changes to the learning rate
    from outside this scheduler. When last_epoch=-1, sets initial lr as lr.

    Parameters
    ----------
    factor: float, default=1./3.
        The number we multiply the learning rate by until the milestone.

    total_iters: int, default=5
        The number of steps for which the scheduler decays the learning rate.

    last_epoch: int, default=-1
        The index of the last epoch.

    verbose: bool, default=False
        If ``True``, prints a message to stdout for each update.

    Notes
    -----
    This class works the same as ``torch.optim.lr_scheduler.ConstantLR``.
    The only difference, which is also the reason we reimplement it, is that you don't have to
    pass the corresponding optimizer into it immediately when initializing it.

    Example
    -------
    >>> # Assuming optimizer uses lr = 0.05 for all groups
    >>> # lr = 0.025   if epoch == 0
    >>> # lr = 0.025   if epoch == 1
    >>> # lr = 0.025   if epoch == 2
    >>> # lr = 0.025   if epoch == 3
    >>> # lr = 0.05    if epoch >= 4
    >>> # xdoctest: +SKIP
    >>> scheduler = ConstantLR(factor=0.5, total_iters=4)
    >>> adam = pypots.optim.Adam(lr=1e-3, lr_scheduler=scheduler)

    """

    def __init__(self, factor=1.0 / 3, total_iters=5, last_epoch=-1, verbose=False):
        super().__init__(last_epoch, verbose)
        if factor > 1.0 or factor < 0:
            raise ValueError("Constant multiplicative factor expected to be between 0 and 1.")

        self.factor = factor
        self.total_iters = total_iters
    def get_lr(self):
        if not self._get_lr_called_within_step:
            logger.warning(
                "⚠️ To get the last learning rate computed by the scheduler, please use `get_last_lr()`."
            )

        # at the very first step, scale the base learning rate down by the constant factor
        if self.last_epoch == 0:
            return [group["lr"] * self.factor for group in self.optimizer.param_groups]

        # before and after the milestone, keep the current learning rate unchanged
        if self.last_epoch != self.total_iters:
            return [group["lr"] for group in self.optimizer.param_groups]

        # exactly at the milestone, divide the factor back out to restore the original learning rate
        return [group["lr"] * (1.0 / self.factor) for group in self.optimizer.param_groups]
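

# A minimal usage sketch (illustrative only, not part of the module source). It mirrors the
# docstring example above and assumes ``pypots.optim.Adam`` accepts an ``lr_scheduler``
# argument, as shown there; the model training loop that would consume the optimizer is
# omitted. The helper name below is hypothetical.
def _constant_lr_usage_example():
    import pypots.optim

    # With a base lr of 1e-3 and factor=0.5, the effective lr is 5e-4 for the first
    # total_iters=4 epochs and returns to 1e-3 from epoch 4 onward.
    scheduler = ConstantLR(factor=0.5, total_iters=4)
    adam = pypots.optim.Adam(lr=1e-3, lr_scheduler=scheduler)
    return adam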