1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374 |
- # Copyright (c) OpenMMLab. All rights reserved.
- from unittest import TestCase
- from unittest.mock import Mock, patch
- from mmdet.engine.hooks import YOLOXModeSwitchHook
class TestYOLOXModeSwitchHook(TestCase):
    """Tests for ``YOLOXModeSwitchHook.before_train_epoch``.

    Every test drives the hook with a ``Mock`` runner configured with
    ``max_epochs=300`` and a hook built with ``num_last_epochs=15``, then
    checks three observable effects: the hook's ``_restart_dataloader`` flag,
    the ``use_l1`` flag on the detector's bbox head, and the dataloader's
    ``_DataLoader__initialized`` attribute.
    """

    @patch('mmdet.engine.hooks.yolox_mode_switch_hook.is_model_wrapper')
    def test_is_model_wrapper_and_persistent_workers_on(
            self, mock_is_model_wrapper):
        """Wrapped model with persistent workers: restart, then restore."""
        # Force the hook to treat ``runner.model`` as a wrapper, so the
        # detector is looked up under ``runner.model.module``.
        mock_is_model_wrapper.return_value = True

        runner = Mock()
        runner.model = Mock()
        runner.model.module = Mock()
        runner.model.module.detector.bbox_head.use_l1 = False
        runner.train_dataloader = Mock()
        runner.train_dataloader.persistent_workers = True
        runner.train_dataloader._DataLoader__initialized = True
        # epoch + 1 (285) == max_epochs - num_last_epochs (285).
        runner.epoch = 284
        runner.max_epochs = 300

        hook = YOLOXModeSwitchHook(num_last_epochs=15)
        hook.before_train_epoch(runner)
        self.assertTrue(hook._restart_dataloader)
        self.assertTrue(runner.model.module.detector.bbox_head.use_l1)
        self.assertFalse(runner.train_dataloader._DataLoader__initialized)

        # On the following epoch the initialized flag is expected to be
        # set back to True.
        runner.epoch = 285
        hook.before_train_epoch(runner)
        self.assertTrue(runner.train_dataloader._DataLoader__initialized)

    def test_not_model_wrapper_and_persistent_workers_off(self):
        """Plain model without persistent workers: no dataloader restart."""
        # ``is_model_wrapper`` is not patched here; the detector is set up
        # directly on ``runner.model`` (no ``module`` indirection).
        runner = Mock()
        runner.model = Mock()
        runner.model.detector.bbox_head.use_l1 = False
        runner.train_dataloader = Mock()
        runner.train_dataloader.persistent_workers = False
        runner.train_dataloader._DataLoader__initialized = True
        # epoch + 1 (285) == max_epochs - num_last_epochs (285).
        runner.epoch = 284
        runner.max_epochs = 300

        hook = YOLOXModeSwitchHook(num_last_epochs=15)
        hook.before_train_epoch(runner)
        self.assertFalse(hook._restart_dataloader)
        self.assertTrue(runner.model.detector.bbox_head.use_l1)
        self.assertTrue(runner.train_dataloader._DataLoader__initialized)

        # The next epoch must leave the dataloader state untouched as well.
        runner.epoch = 285
        hook.before_train_epoch(runner)
        self.assertFalse(hook._restart_dataloader)
        self.assertTrue(runner.train_dataloader._DataLoader__initialized)

    @patch('mmdet.engine.hooks.yolox_mode_switch_hook.is_model_wrapper')
    def test_initialize_after_switching(self, mock_is_model_wrapper):
        """Resuming after the switch epoch must still apply the switch."""
        # This simulates the resumption after the switching.
        mock_is_model_wrapper.return_value = True

        runner = Mock()
        runner.model = Mock()
        runner.model.module = Mock()
        # Seed ``use_l1`` on the same attribute path that is asserted below.
        # An attribute that was never set on a Mock is itself a truthy Mock,
        # so seeding a different path would make ``assertTrue`` pass even if
        # the hook never enabled L1.
        runner.model.module.detector.bbox_head.use_l1 = False
        runner.train_dataloader = Mock()
        runner.train_dataloader.persistent_workers = True
        runner.train_dataloader._DataLoader__initialized = True
        runner.epoch = 285
        runner.max_epochs = 300

        # epoch + 1 (286) > max_epochs - num_last_epochs (285).
        hook = YOLOXModeSwitchHook(num_last_epochs=15)
        hook.before_train_epoch(runner)
        self.assertTrue(hook._restart_dataloader)
        self.assertTrue(runner.model.module.detector.bbox_head.use_l1)
        self.assertFalse(runner.train_dataloader._DataLoader__initialized)
|