123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131 |
- # Copyright (c) OpenMMLab. All rights reserved.
- from mmengine.optim.scheduler.lr_scheduler import LRSchedulerMixin
- from mmengine.optim.scheduler.momentum_scheduler import MomentumSchedulerMixin
- from mmengine.optim.scheduler.param_scheduler import INF, _ParamScheduler
- from torch.optim import Optimizer
- from mmdet.registry import PARAM_SCHEDULERS
@PARAM_SCHEDULERS.register_module()
class QuadraticWarmupParamScheduler(_ParamScheduler):
    r"""Warm up the parameter value of each parameter group by quadratic
    formula:

    .. math::

        X_{t} = X_{t-1} + \frac{2t+1}{{(end-begin)}^{2}} \times X_{base}

    Because :math:`\sum_{k=0}^{t} (2k+1) = (t+1)^2`, the value after step
    :math:`t` equals :math:`X_{base} (t+1)^2 / (end-begin)^2`. It therefore
    grows quadratically and reaches :math:`X_{base}` after ``end - begin``
    updates. This assumes the step counter starts at 0 and advances by one
    per update.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        param_name (str): Name of the parameter to be adjusted, such as
            ``lr``, ``momentum``.
        begin (int): Step at which to start updating the parameters.
            Defaults to 0.
        end (int): Step at which to stop updating the parameters. Must be
            finite: the default ``INF`` is rejected, so in practice this
            argument has to be given. Defaults to INF.
        last_step (int): The index of last step. Used for resume without
            state dict. Defaults to -1.
        by_epoch (bool): Whether the scheduled parameters are updated by
            epochs. Defaults to True.
        verbose (bool): Whether to print the value for each update.
            Defaults to False.

    Raises:
        ValueError: If ``end`` is not less than ``INF``.
    """

    def __init__(self,
                 optimizer: Optimizer,
                 param_name: str,
                 begin: int = 0,
                 end: int = INF,
                 last_step: int = -1,
                 by_epoch: bool = True,
                 verbose: bool = False):
        # ``total_iters`` is the denominator of every increment computed in
        # ``_get_value``, so an unbounded warmup has no meaningful length.
        if end >= INF:
            raise ValueError(
                '``end`` must be less than infinity. Please set the ``end`` '
                f'parameter of ``{type(self).__name__}`` to the step at '
                'which warmup ends.')
        # Number of steps over which the value ramps up to its base value.
        self.total_iters = end - begin
        super().__init__(
            optimizer=optimizer,
            param_name=param_name,
            begin=begin,
            end=end,
            last_step=last_step,
            by_epoch=by_epoch,
            verbose=verbose)

    @classmethod
    def build_iter_from_epoch(cls,
                              *args,
                              begin=0,
                              end=INF,
                              by_epoch=True,
                              epoch_length=None,
                              **kwargs):
        """Build an iter-based instance of this scheduler from an epoch-based
        config.

        ``begin`` and ``end`` (given in epochs) are multiplied by
        ``epoch_length`` so the warmup covers the same span measured in
        iterations.

        Args:
            *args: Positional arguments forwarded to the constructor.
            begin (int): Epoch at which to start updating. Defaults to 0.
            end (int): Epoch at which to stop updating. Defaults to INF,
                which the constructor rejects.
            by_epoch (bool): Must be True, since the config being converted
                is epoch-based.
            epoch_length (int): Number of iterations per epoch. Must be a
                positive integer.
            **kwargs: Keyword arguments forwarded to the constructor.

        Returns:
            An instance of ``cls`` created with ``by_epoch=False``.
        """
        assert by_epoch, 'Only epoch-based kwargs whose `by_epoch=True` can ' \
                         'be converted to iter-based.'
        assert epoch_length is not None and epoch_length > 0, \
            '`epoch_length` must be a positive integer, ' \
            f'but got {epoch_length}.'
        by_epoch = False
        begin = begin * epoch_length
        # ``INF`` is a sentinel for "unbounded" and is passed through as-is
        # rather than scaled.
        if end != INF:
            end = end * epoch_length
        return cls(*args, begin=begin, end=end, by_epoch=by_epoch, **kwargs)

    def _get_value(self):
        """Compute value using chainable form of the scheduler.

        Each update adds ``base_value * (2t + 1) / total_iters**2`` to the
        current value of the parameter group, where ``t`` is
        ``self.last_step``. At step 0 the current value is ignored and the
        increment alone is returned, so warmup starts from (almost) zero.

        Returns:
            list: One new value per parameter group.
        """

        def increment(base_value):
            # Keep the evaluation order of the original expression so the
            # results are bit-identical.
            return base_value * (2 * self.last_step + 1) / self.total_iters**2

        if self.last_step == 0:
            return [increment(base_value) for base_value in self.base_values]
        return [
            group[self.param_name] + increment(base_value)
            for base_value, group in zip(self.base_values,
                                         self.optimizer.param_groups)
        ]
@PARAM_SCHEDULERS.register_module()
class QuadraticWarmupLR(
        LRSchedulerMixin,
        QuadraticWarmupParamScheduler,
):
    """Quadratically warm up the learning rate of every parameter group.

    Uses the quadratic schedule of :class:`QuadraticWarmupParamScheduler`,
    applied to the learning rate. No ``param_name`` argument is taken here;
    presumably ``LRSchedulerMixin`` supplies ``lr`` (confirm in mmengine).

    Args:
        optimizer (Optimizer): Optimizer whose learning rate is scheduled.
        begin (int): First step at which the learning rate is updated.
            Defaults to 0.
        end (int): Step at which updating stops. Must be finite, because
            the default ``INF`` is rejected with ``ValueError``.
            Defaults to INF.
        last_step (int): Index of the last step, for resuming without a
            state dict. Defaults to -1.
        by_epoch (bool): Whether updates happen per epoch rather than per
            iteration. Defaults to True.
        verbose (bool): Whether each update's value is printed.
            Defaults to False.
    """
@PARAM_SCHEDULERS.register_module()
class QuadraticWarmupMomentum(
        MomentumSchedulerMixin, QuadraticWarmupParamScheduler):
    """Quadratically warm up the momentum of every parameter group.

    Uses the quadratic schedule of :class:`QuadraticWarmupParamScheduler`,
    applied to the momentum. No ``param_name`` argument is taken here;
    presumably ``MomentumSchedulerMixin`` supplies it (confirm in mmengine).

    Args:
        optimizer (Optimizer): Optimizer whose momentum is scheduled.
        begin (int): First step at which the momentum is updated.
            Defaults to 0.
        end (int): Step at which updating stops. Must be finite, because
            the default ``INF`` is rejected with ``ValueError``.
            Defaults to INF.
        last_step (int): Index of the last step, for resuming without a
            state dict. Defaults to -1.
        by_epoch (bool): Whether updates happen per epoch rather than per
            iteration. Defaults to True.
        verbose (bool): Whether each update's value is printed.
            Defaults to False.
    """
|