Source code for pypots.optim.lr_scheduler.multistep_lrs
"""Multistep learning rate scheduler."""# Created by Wenjie Du <wenjay.du@gmail.com># License: BSD-3-Clausefrombisectimportbisect_rightfromcollectionsimportCounterfrom.baseimportLRScheduler,logger
class MultiStepLR(LRScheduler):
    """Multiply each parameter group's learning rate by ``gamma`` whenever the
    epoch counter reaches one of the milestones.

    The decay is applied to the group's *current* learning rate, so it composes
    with learning-rate changes made to the optimizer outside this scheduler.
    When last_epoch=-1, sets initial lr as lr.

    Parameters
    ----------
    milestones: list,
        List of epoch indices. Must be increasing. An epoch index that appears
        ``k`` times in the list applies the decay ``k`` times at that epoch.

    gamma: float, default=0.1,
        Multiplicative factor of learning rate decay.

    last_epoch: int, default=-1,
        The index of last epoch.

    verbose: bool, default=False,
        If ``True``, prints a message to stdout for each update.

    Notes
    -----
    This class behaves the same as ``torch.optim.lr_scheduler.MultiStepLR``.
    The only difference, and the reason it is reimplemented here, is that you do
    not have to pass an optimizer to it when constructing it.

    Example
    -------
    >>> # Assuming optimizer uses lr = 0.05 for all groups
    >>> # lr = 0.05     if epoch < 30
    >>> # lr = 0.005    if 30 <= epoch < 80
    >>> # lr = 0.0005   if epoch >= 80
    >>> # xdoctest: +SKIP
    >>> scheduler = MultiStepLR(milestones=[30,80], gamma=0.1)
    >>> adam = pypots.optim.Adam(lr=0.05, lr_scheduler=scheduler)

    """

    def __init__(self, milestones, gamma=0.1, last_epoch=-1, verbose=False):
        super().__init__(last_epoch, verbose)
        self.gamma = gamma
        # A Counter (not a set) keeps multiplicity: listing an epoch twice
        # makes get_lr() apply gamma twice at that epoch.
        self.milestones = Counter(milestones)

    def get_lr(self):
        """Compute the learning rate of every parameter group at ``self.last_epoch``.

        Returns
        -------
        list
            One value per entry of ``self.optimizer.param_groups``. If
            ``self.last_epoch`` is a milestone, the value is the group's current
            ``"lr"`` times ``gamma ** count``, where ``count`` is how often that
            epoch appears in ``milestones``. Otherwise it is the group's current
            ``"lr"`` unchanged.
        """
        if not self._get_lr_called_within_step:
            # NOTE(review): the flag is presumably set by the base class's
            # step(); a direct call should use get_last_lr() instead.
            logger.warning(
                "⚠️ To get the last learning rate computed by the scheduler, please use `get_last_lr()`.",
            )

        epoch = self.last_epoch
        groups = self.optimizer.param_groups

        if epoch in self.milestones:
            hits = self.milestones[epoch]
            return [group["lr"] * self.gamma**hits for group in groups]

        # Not a milestone: pass each group's current lr through untouched.
        return [group["lr"] for group in groups]