fp16_compression_hook.py 920 B

12345678910111213141516171819202122232425
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from mmengine.hooks import Hook
  3. from mmdet.registry import HOOKS
  4. @HOOKS.register_module()
  5. class Fp16CompresssionHook(Hook):
  6. """Support fp16 compression in DDP mode.
  7. In detectron2, vitdet use Fp16CompresssionHook in training process
  8. Fp16CompresssionHook can reduce training time and improve bbox mAP when you
  9. use Fp16CompresssionHook, training time reduce form 3 days to 2 days and
  10. box mAP from 51.4 to 51.6
  11. """
  12. def before_train(self, runner):
  13. if runner.distributed:
  14. if runner.cfg.get('model_wrapper_cfg') is None:
  15. from torch.distributed.algorithms.ddp_comm_hooks import \
  16. default as comm_hooks
  17. runner.model.register_comm_hook(
  18. state=None, hook=comm_hooks.fp16_compress_hook)
  19. runner.logger.info('use fp16 compression in DDP mode')