123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384 |
- # Copyright (c) OpenMMLab. All rights reserved.
- from unittest import TestCase
- from unittest.mock import Mock
- from mmdet.engine.hooks import PipelineSwitchHook
class TestPipelineSwitchHook(TestCase):
    """Tests for ``PipelineSwitchHook.before_train_epoch``.

    Each test drives the hook with a ``Mock`` runner. Only the dataloader
    attributes the assertions inspect are set explicitly:
    ``persistent_workers``, ``_DataLoader__initialized`` (the name-mangled
    form of a ``DataLoader``'s ``__initialized`` attribute) and, after the
    switch, ``dataset.pipeline``.
    """

    # Epoch at which the hook is configured to swap in the stage-2 pipeline.
    SWITCH_EPOCH = 285

    def _build_runner(self, persistent_workers, epoch):
        """Return a mock runner whose dataloader starts out initialized.

        Args:
            persistent_workers (bool): Value assigned to the dataloader's
                ``persistent_workers`` attribute.
            epoch (int): Initial value of ``runner.epoch``.

        Returns:
            Mock: The configured runner.
        """
        # Mock auto-creates any other attribute the hook touches
        # (e.g. ``runner.model`` or ``runner.logger``).
        runner = Mock()
        runner.train_dataloader = Mock()
        runner.train_dataloader.persistent_workers = persistent_workers
        runner.train_dataloader._DataLoader__initialized = True
        runner.epoch = epoch
        return runner

    def _build_hook(self):
        """Return a hook that switches to a single-transform pipeline."""
        # Built fresh per test so no test can observe another's mutations.
        stage2 = [dict(type='RandomResize', scale=(1280, 1280))]
        return PipelineSwitchHook(
            switch_epoch=self.SWITCH_EPOCH, switch_pipeline=stage2)

    def _assert_pipeline_switched(self, runner):
        """Assert the dataset pipeline now holds exactly the stage-2 step."""
        self.assertEqual(
            len(runner.train_dataloader.dataset.pipeline.transforms), 1)

    def test_persistent_workers_on(self):
        """With persistent workers, the switch also toggles the init flag."""
        runner = self._build_runner(
            persistent_workers=True, epoch=self.SWITCH_EPOCH - 1)
        hook = self._build_hook()
        loader = runner.train_dataloader

        # epoch < switch_epoch: nothing is touched yet.
        hook.before_train_epoch(runner)
        self.assertFalse(hook._restart_dataloader)
        self.assertTrue(loader._DataLoader__initialized)

        # epoch == switch_epoch: the pipeline is replaced and the loader is
        # marked uninitialized (presumably so the persistent workers are
        # rebuilt with the new pipeline).
        runner.epoch = self.SWITCH_EPOCH
        hook.before_train_epoch(runner)
        self.assertTrue(hook._restart_dataloader)
        self.assertFalse(loader._DataLoader__initialized)
        self._assert_pipeline_switched(runner)

        # epoch > switch_epoch: the init flag is restored and the switched
        # pipeline stays in place.
        runner.epoch = self.SWITCH_EPOCH + 1
        hook.before_train_epoch(runner)
        self.assertTrue(loader._DataLoader__initialized)
        self._assert_pipeline_switched(runner)

    def test_persistent_workers_off(self):
        """Without persistent workers, only the pipeline is swapped."""
        runner = self._build_runner(
            persistent_workers=False, epoch=self.SWITCH_EPOCH - 1)
        hook = self._build_hook()
        loader = runner.train_dataloader

        # epoch < switch_epoch: nothing is touched yet.
        hook.before_train_epoch(runner)
        self.assertFalse(hook._restart_dataloader)
        self.assertTrue(loader._DataLoader__initialized)

        # epoch == switch_epoch: the pipeline changes, but no restart is
        # requested and the init flag is left alone.
        runner.epoch = self.SWITCH_EPOCH
        hook.before_train_epoch(runner)
        self.assertFalse(hook._restart_dataloader)
        self.assertTrue(loader._DataLoader__initialized)
        self._assert_pipeline_switched(runner)

        # epoch > switch_epoch: state is unchanged.
        runner.epoch = self.SWITCH_EPOCH + 1
        hook.before_train_epoch(runner)
        self.assertTrue(loader._DataLoader__initialized)
        self._assert_pipeline_switched(runner)

    def test_initialize_after_switching(self):
        """Resuming past the switch epoch still performs the switch once."""
        # This simulates resuming training after the switch epoch, with a
        # freshly constructed hook that has not switched yet.
        runner = self._build_runner(
            persistent_workers=True, epoch=self.SWITCH_EPOCH + 1)
        hook = self._build_hook()

        hook.before_train_epoch(runner)
        self.assertTrue(hook._restart_dataloader)
        self.assertFalse(runner.train_dataloader._DataLoader__initialized)
        self._assert_pipeline_switched(runner)
|