quadratic_warmup.py

# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.optim.scheduler.lr_scheduler import LRSchedulerMixin
from mmengine.optim.scheduler.momentum_scheduler import MomentumSchedulerMixin
from mmengine.optim.scheduler.param_scheduler import INF, _ParamScheduler
from torch.optim import Optimizer

from mmdet.registry import PARAM_SCHEDULERS


@PARAM_SCHEDULERS.register_module()
class QuadraticWarmupParamScheduler(_ParamScheduler):
    r"""Warm up the parameter value of each parameter group by quadratic
    formula:

    .. math::

        X_{t} = X_{t-1} + \frac{2t+1}{{(end-begin)}^{2}} \times X_{base}

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        param_name (str): Name of the parameter to be adjusted, such as
            ``lr``, ``momentum``.
        begin (int): Step at which to start updating the parameters.
            Defaults to 0.
        end (int): Step at which to stop updating the parameters.
            Defaults to INF.
        last_step (int): The index of last step. Used for resume without
            state dict. Defaults to -1.
        by_epoch (bool): Whether the scheduled parameters are updated by
            epochs. Defaults to True.
        verbose (bool): Whether to print the value for each update.
            Defaults to False.
    """
    def __init__(self,
                 optimizer: Optimizer,
                 param_name: str,
                 begin: int = 0,
                 end: int = INF,
                 last_step: int = -1,
                 by_epoch: bool = True,
                 verbose: bool = False):
        if end >= INF:
            raise ValueError('``end`` must be less than infinity. Please '
                             'set the ``end`` parameter of '
                             '``QuadraticWarmupScheduler`` to the step at '
                             'which warmup ends.')
        self.total_iters = end - begin
        super().__init__(
            optimizer=optimizer,
            param_name=param_name,
            begin=begin,
            end=end,
            last_step=last_step,
            by_epoch=by_epoch,
            verbose=verbose)
    @classmethod
    def build_iter_from_epoch(cls,
                              *args,
                              begin=0,
                              end=INF,
                              by_epoch=True,
                              epoch_length=None,
                              **kwargs):
        """Build an iter-based instance of this scheduler from an epoch-based
        config."""
        assert by_epoch, 'Only epoch-based kwargs whose `by_epoch=True` can ' \
                         'be converted to iter-based.'
        assert epoch_length is not None and epoch_length > 0, \
            f'`epoch_length` must be a positive integer, ' \
            f'but got {epoch_length}.'
        by_epoch = False
        begin = begin * epoch_length
        if end != INF:
            end = end * epoch_length
        return cls(*args, begin=begin, end=end, by_epoch=by_epoch, **kwargs)
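
    # Illustrative sketch, not in the original file: converting a 5-epoch
    # warmup into an iter-based schedule, assuming a hypothetical dataloader
    # with 100 iterations per epoch:
    #
    #   scheduler = QuadraticWarmupParamScheduler.build_iter_from_epoch(
    #       optimizer, param_name='lr', end=5, epoch_length=100)
    #   # equivalent to begin=0, end=500, by_epoch=False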
    def _get_value(self):
        """Compute value using chainable form of the scheduler."""
        if self.last_step == 0:
            return [
                base_value * (2 * self.last_step + 1) / self.total_iters**2
                for base_value in self.base_values
            ]
        return [
            group[self.param_name] + base_value *
            (2 * self.last_step + 1) / self.total_iters**2
            for base_value, group in zip(self.base_values,
                                         self.optimizer.param_groups)
        ]
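
# Illustrative sketch, not in the original file: the chainable increments in
# ``_get_value`` telescope to ``X_base * sum_{t=0}^{T-1} (2t+1) / T**2``,
# which equals ``X_base`` because ``sum_{t=0}^{T-1} (2t+1) == T**2``. A
# minimal check with a toy SGD optimizer:
#
#   import torch
#   from torch.optim import SGD
#
#   params = [torch.nn.Parameter(torch.zeros(1))]
#   optimizer = SGD(params, lr=0.1)
#   scheduler = QuadraticWarmupParamScheduler(
#       optimizer, param_name='lr', begin=0, end=10, by_epoch=False)
#   for _ in range(9):
#       scheduler.step()
#   # optimizer.param_groups[0]['lr'] has ramped quadratically up to ~0.1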

@PARAM_SCHEDULERS.register_module()
class QuadraticWarmupLR(LRSchedulerMixin, QuadraticWarmupParamScheduler):
    """Warm up the learning rate of each parameter group by quadratic
    formula.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        begin (int): Step at which to start updating the parameters.
            Defaults to 0.
        end (int): Step at which to stop updating the parameters.
            Defaults to INF.
        last_step (int): The index of last step. Used for resume without
            state dict. Defaults to -1.
        by_epoch (bool): Whether the scheduled parameters are updated by
            epochs. Defaults to True.
        verbose (bool): Whether to print the value for each update.
            Defaults to False.
    """

@PARAM_SCHEDULERS.register_module()
class QuadraticWarmupMomentum(MomentumSchedulerMixin,
                              QuadraticWarmupParamScheduler):
    """Warm up the momentum value of each parameter group by quadratic
    formula.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        begin (int): Step at which to start updating the parameters.
            Defaults to 0.
        end (int): Step at which to stop updating the parameters.
            Defaults to INF.
        last_step (int): The index of last step. Used for resume without
            state dict. Defaults to -1.
        by_epoch (bool): Whether the scheduled parameters are updated by
            epochs. Defaults to True.
        verbose (bool): Whether to print the value for each update.
            Defaults to False.
    """