test_yolox_mode_switch_hook.py 3.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from unittest import TestCase
  3. from unittest.mock import Mock, patch
  4. from mmdet.engine.hooks import YOLOXModeSwitchHook
  5. class TestYOLOXModeSwitchHook(TestCase):
  6. @patch('mmdet.engine.hooks.yolox_mode_switch_hook.is_model_wrapper')
  7. def test_is_model_wrapper_and_persistent_workers_on(
  8. self, mock_is_model_wrapper):
  9. mock_is_model_wrapper.return_value = True
  10. runner = Mock()
  11. runner.model = Mock()
  12. runner.model.module = Mock()
  13. runner.model.module.detector.bbox_head.use_l1 = False
  14. runner.train_dataloader = Mock()
  15. runner.train_dataloader.persistent_workers = True
  16. runner.train_dataloader._DataLoader__initialized = True
  17. runner.epoch = 284
  18. runner.max_epochs = 300
  19. hook = YOLOXModeSwitchHook(num_last_epochs=15)
  20. hook.before_train_epoch(runner)
  21. self.assertTrue(hook._restart_dataloader)
  22. self.assertTrue(runner.model.module.detector.bbox_head.use_l1)
  23. self.assertFalse(runner.train_dataloader._DataLoader__initialized)
  24. runner.epoch = 285
  25. hook.before_train_epoch(runner)
  26. self.assertTrue(runner.train_dataloader._DataLoader__initialized)
  27. def test_not_model_wrapper_and_persistent_workers_off(self):
  28. runner = Mock()
  29. runner.model = Mock()
  30. runner.model.detector.bbox_head.use_l1 = False
  31. runner.train_dataloader = Mock()
  32. runner.train_dataloader.persistent_workers = False
  33. runner.train_dataloader._DataLoader__initialized = True
  34. runner.epoch = 284
  35. runner.max_epochs = 300
  36. hook = YOLOXModeSwitchHook(num_last_epochs=15)
  37. hook.before_train_epoch(runner)
  38. self.assertFalse(hook._restart_dataloader)
  39. self.assertTrue(runner.model.detector.bbox_head.use_l1)
  40. self.assertTrue(runner.train_dataloader._DataLoader__initialized)
  41. runner.epoch = 285
  42. hook.before_train_epoch(runner)
  43. self.assertFalse(hook._restart_dataloader)
  44. self.assertTrue(runner.train_dataloader._DataLoader__initialized)
  45. @patch('mmdet.engine.hooks.yolox_mode_switch_hook.is_model_wrapper')
  46. def test_initialize_after_switching(self, mock_is_model_wrapper):
  47. # This simulates the resumption after the switching.
  48. mock_is_model_wrapper.return_value = True
  49. runner = Mock()
  50. runner.model = Mock()
  51. runner.model.module = Mock()
  52. runner.model.module.bbox_head.use_l1 = False
  53. runner.train_dataloader = Mock()
  54. runner.train_dataloader.persistent_workers = True
  55. runner.train_dataloader._DataLoader__initialized = True
  56. runner.epoch = 285
  57. runner.max_epochs = 300
  58. # epoch + 1 > max_epochs - num_last_epochs .
  59. hook = YOLOXModeSwitchHook(num_last_epochs=15)
  60. hook.before_train_epoch(runner)
  61. self.assertTrue(hook._restart_dataloader)
  62. self.assertTrue(runner.model.module.detector.bbox_head.use_l1)
  63. self.assertFalse(runner.train_dataloader._DataLoader__initialized)