Source code for mipcandy.common.optim.lr_scheduler

from typing import override

from torch import optim


[docs] class AbsoluteLinearLR(optim.lr_scheduler.LRScheduler): """ lr = kx + b """ def __init__(self, optimizer: optim.Optimizer, k: float, b: float, *, min_lr: float = 1e-6, restart: bool = False, last_epoch: int = -1) -> None: self._k: float = k self._b: float = b if min_lr < 0: raise ValueError(f"`min_lr` must be positive, but got {min_lr}") self._min_lr: float = min_lr self._restart: bool = restart self._restart_step: int = 0 super().__init__(optimizer, last_epoch)
[docs] def _interp(self, epoch: int) -> float: epoch -= self._restart_step r = self._k * epoch + self._b if r < self._min_lr: if self._restart: self._restart_step = epoch return self._interp(epoch) return self._min_lr return r
[docs] @override def get_lr(self) -> list[float]: target = self._interp(self.last_epoch) return [target for _ in self.optimizer.param_groups]
[docs] class PolyLRScheduler(optim.lr_scheduler.LRScheduler): def __init__(self, optimizer: optim.Optimizer, initial_lr: float, max_steps: int, *, exponent: float = .9, last_epoch: int = -1) -> None: self._initial_lr: float = initial_lr self._max_steps: int = max_steps self._exponent: float = exponent super().__init__(optimizer, last_epoch)
[docs] def _interp(self, epoch: int) -> float: return self._initial_lr * (1 - epoch / self._max_steps) ** self._exponent
[docs] @override def get_lr(self) -> list[float]: target = self._interp(self.last_epoch) return [target for _ in self.optimizer.param_groups]