layer_decay_optimizer_constructor.py

# Copyright (c) OpenMMLab. All rights reserved.
import json
from typing import List

import torch.nn as nn
from mmengine.dist import get_dist_info
from mmengine.logging import MMLogger
from mmengine.optim import DefaultOptimWrapperConstructor

from mmdet.registry import OPTIM_WRAPPER_CONSTRUCTORS


def get_layer_id_for_vit(var_name, max_layer_id):
    """Get the layer id to set the different learning rates in ``layer_wise``
    decay_type.

    Args:
        var_name (str): The key of the model.
        max_layer_id (int): Maximum layer id.

    Returns:
        int: The id number corresponding to different learning rate in
        ``LayerDecayOptimizerConstructor``.
    """
    if var_name.startswith('backbone'):
        if 'patch_embed' in var_name or 'pos_embed' in var_name:
            return 0
        elif '.blocks.' in var_name:
            layer_id = int(var_name.split('.')[2]) + 1
            return layer_id
        else:
            return max_layer_id + 1
    else:
        return max_layer_id + 1
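

# Illustrative mapping (a sketch only; the parameter names below are
# hypothetical examples for a ViT backbone, assuming max_layer_id=12):
#   'backbone.patch_embed.projection.weight' -> 0
#   'backbone.pos_embed'                     -> 0
#   'backbone.blocks.0.attn.qkv.weight'      -> 1   (block index + 1)
#   'backbone.blocks.11.mlp.fc2.bias'        -> 12
#   'backbone.ln1.weight'                    -> 13  (max_layer_id + 1)
#   'neck.lateral_convs.0.conv.weight'       -> 13  (non-backbone params)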


@OPTIM_WRAPPER_CONSTRUCTORS.register_module()
class LayerDecayOptimizerConstructor(DefaultOptimWrapperConstructor):
    # Different learning rates are set for different layers of backbone.
    # Note: Currently, this optimizer constructor is built for ViT.

    def add_params(self, params: List[dict], module: nn.Module,
                   **kwargs) -> None:
        """Add all parameters of module to the params list.

        The parameters of the given module will be added to the list of param
        groups, with specific rules defined by paramwise_cfg.

        Args:
            params (list[dict]): A list of param groups, it will be modified
                in place.
            module (nn.Module): The module to be added.
        """
        logger = MMLogger.get_current_instance()

        parameter_groups = {}
        logger.info(f'self.paramwise_cfg is {self.paramwise_cfg}')
        num_layers = self.paramwise_cfg.get('num_layers') + 2
        decay_rate = self.paramwise_cfg.get('decay_rate')
        decay_type = self.paramwise_cfg.get('decay_type', 'layer_wise')
        logger.info('Build LayerDecayOptimizerConstructor '
                    f'{decay_type} {decay_rate} - {num_layers}')
        weight_decay = self.base_wd

        for name, param in module.named_parameters():
            if not param.requires_grad:
                continue  # frozen weights

            # Norm weights inside transformer blocks and the position
            # embedding are excluded from weight decay.
            if name.startswith('backbone.blocks') and 'norm' in name:
                group_name = 'no_decay'
                this_weight_decay = 0.
            elif 'pos_embed' in name:
                group_name = 'no_decay_pos_embed'
                this_weight_decay = 0.
            else:
                group_name = 'decay'
                this_weight_decay = weight_decay

            layer_id = get_layer_id_for_vit(
                name, self.paramwise_cfg.get('num_layers'))
            logger.info(f'set param {name} as id {layer_id}')

            group_name = f'layer_{layer_id}_{group_name}'
            this_lr_multi = 1.

            if group_name not in parameter_groups:
                scale = decay_rate**(num_layers - 1 - layer_id)

                parameter_groups[group_name] = {
                    'weight_decay': this_weight_decay,
                    'params': [],
                    'param_names': [],
                    'lr_scale': scale,
                    'group_name': group_name,
                    'lr': scale * self.base_lr * this_lr_multi,
                }

            parameter_groups[group_name]['params'].append(param)
            parameter_groups[group_name]['param_names'].append(name)

        rank, _ = get_dist_info()
        if rank == 0:
            to_display = {}
            for key in parameter_groups:
                to_display[key] = {
                    'param_names': parameter_groups[key]['param_names'],
                    'lr_scale': parameter_groups[key]['lr_scale'],
                    'lr': parameter_groups[key]['lr'],
                    'weight_decay': parameter_groups[key]['weight_decay'],
                }
            logger.info(f'Param groups = {json.dumps(to_display, indent=2)}')
        params.extend(parameter_groups.values())
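

# Usage sketch (not part of this module): a minimal, hypothetical config
# snippet assuming an MMEngine-style ``optim_wrapper`` with an AdamW
# optimizer and a 12-layer ViT backbone.
#
#   optim_wrapper = dict(
#       type='OptimWrapper',
#       optimizer=dict(type='AdamW', lr=1e-4, weight_decay=0.05),
#       constructor='LayerDecayOptimizerConstructor',
#       paramwise_cfg=dict(
#           num_layers=12,       # depth of the ViT backbone
#           decay_rate=0.9,      # layer-wise lr decay factor
#           decay_type='layer_wise'))
#
# With these values a parameter in layer id ``i`` receives
# ``lr = base_lr * decay_rate ** (num_layers + 1 - i)``, so the patch/pos
# embeddings (id 0) get the smallest learning rate and parameters outside
# the backbone (id 13) train at the full base learning rate.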