yolox_mode_switch_hook.py

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Sequence

from mmengine.hooks import Hook
from mmengine.model import is_model_wrapper

from mmdet.registry import HOOKS


@HOOKS.register_module()
class YOLOXModeSwitchHook(Hook):
    """Switch the mode of YOLOX during training.

    This hook turns off the mosaic and mixup data augmentation and switches
    to use L1 loss in bbox_head.

    Args:
        num_last_epochs (int): The number of epochs at the end of training
            during which the data augmentation is turned off and the L1
            loss is switched on. Defaults to 15.
        skip_type_keys (Sequence[str], optional): Sequence of transform
            type names to be skipped in the pipeline. Defaults to
            ('Mosaic', 'RandomAffine', 'MixUp').
    """

    def __init__(
        self,
        num_last_epochs: int = 15,
        skip_type_keys: Sequence[str] = ('Mosaic', 'RandomAffine', 'MixUp')
    ) -> None:
        self.num_last_epochs = num_last_epochs
        self.skip_type_keys = skip_type_keys
        self._restart_dataloader = False
        self._has_switched = False

    def before_train_epoch(self, runner) -> None:
        """Close mosaic and mixup augmentation and switch to L1 loss."""
        epoch = runner.epoch
        train_loader = runner.train_dataloader
        model = runner.model
        # TODO: refactor after mmengine using model wrapper
        if is_model_wrapper(model):
            model = model.module
        epoch_to_be_switched = ((epoch + 1) >=
                                runner.max_epochs - self.num_last_epochs)
        if epoch_to_be_switched and not self._has_switched:
            runner.logger.info('No mosaic and mixup aug now!')
            # The dataset pipeline cannot be updated when persistent_workers
            # is True, so we need to force the dataloader's multi-process
            # workers to restart. This is a very hacky approach.
            train_loader.dataset.update_skip_type_keys(self.skip_type_keys)
            if hasattr(train_loader, 'persistent_workers'
                       ) and train_loader.persistent_workers is True:
                # `_DataLoader__initialized` is the name-mangled private
                # flag of torch.utils.data.DataLoader. Dropping the cached
                # iterator forces fresh worker processes, which will pick
                # up the updated pipeline on the next epoch.
                train_loader._DataLoader__initialized = False
                train_loader._iterator = None
                self._restart_dataloader = True
            runner.logger.info('Add additional L1 loss now!')
            if hasattr(model, 'detector'):
                model.detector.bbox_head.use_l1 = True
            else:
                model.bbox_head.use_l1 = True
            self._has_switched = True
        else:
            # Once the restart is complete, we need to restore
            # the initialization flag.
            if self._restart_dataloader:
                train_loader._DataLoader__initialized = True
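
In practice this hook is enabled from the training config rather than constructed by hand. A minimal sketch, assuming an mmdet 3.x-style YOLOX config: the priority=48 value mirrors the upstream YOLOX configs, and the max_epochs/num_last_epochs numbers are illustrative, not required values. With these settings the switch happens once (epoch + 1) >= 300 - 15, i.e. at the start of epoch index 284.

max_epochs = 300
num_last_epochs = 15

train_cfg = dict(
    type='EpochBasedTrainLoop', max_epochs=max_epochs, val_interval=10)

custom_hooks = [
    dict(
        type='YOLOXModeSwitchHook',
        num_last_epochs=num_last_epochs,
        priority=48)
]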
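
The hook also assumes the dataloader's dataset exposes an update_skip_type_keys method; in mmdet this is provided by MultiImageMixDataset. Below is a hypothetical minimal sketch of that contract for a custom dataset wrapper. The class and attribute names are illustrative, not mmdet API:

# NOT the mmdet implementation; a sketch of the dataset-side contract
# that YOLOXModeSwitchHook relies on.
class MixImageDatasetSketch:

    def __init__(self, dataset, pipeline):
        self.dataset = dataset
        self.pipeline = pipeline  # list of callable transform objects
        self._skip_type_keys = None

    def update_skip_type_keys(self, skip_type_keys):
        # Record the transform class names (e.g. 'Mosaic', 'MixUp') that
        # __getitem__ should skip from now on.
        self._skip_type_keys = list(skip_type_keys)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        results = self.dataset[idx]
        for transform in self.pipeline:
            # Skip any transform whose class name was disabled by the hook.
            if (self._skip_type_keys is not None
                    and type(transform).__name__ in self._skip_type_keys):
                continue
            results = transform(results)
        return results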