test_pipeline_switch_hook.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from unittest import TestCase
  3. from unittest.mock import Mock
  4. from mmdet.engine.hooks import PipelineSwitchHook
  5. class TestPipelineSwitchHook(TestCase):
  6. def test_persistent_workers_on(self):
  7. runner = Mock()
  8. runner.model = Mock()
  9. runner.model.module = Mock()
  10. runner.train_dataloader = Mock()
  11. runner.train_dataloader.persistent_workers = True
  12. runner.train_dataloader._DataLoader__initialized = True
  13. stage2 = [dict(type='RandomResize', scale=(1280, 1280))]
  14. runner.epoch = 284 # epoch < switch_epoch
  15. hook = PipelineSwitchHook(switch_epoch=285, switch_pipeline=stage2)
  16. hook.before_train_epoch(runner)
  17. self.assertFalse(hook._restart_dataloader)
  18. self.assertTrue(runner.train_dataloader._DataLoader__initialized)
  19. runner.epoch = 285 # epoch == switch_epoch
  20. hook.before_train_epoch(runner)
  21. self.assertTrue(hook._restart_dataloader)
  22. self.assertFalse(runner.train_dataloader._DataLoader__initialized)
  23. self.assertTrue(
  24. len(runner.train_dataloader.dataset.pipeline.transforms) == 1)
  25. runner.epoch = 286 # epoch > switch_epoch
  26. hook.before_train_epoch(runner)
  27. self.assertTrue(runner.train_dataloader._DataLoader__initialized)
  28. self.assertTrue(
  29. len(runner.train_dataloader.dataset.pipeline.transforms) == 1)
  30. def test_persistent_workers_off(self):
  31. runner = Mock()
  32. runner.model = Mock()
  33. runner.train_dataloader = Mock()
  34. runner.train_dataloader.persistent_workers = False
  35. runner.train_dataloader._DataLoader__initialized = True
  36. stage2 = [dict(type='RandomResize', scale=(1280, 1280))]
  37. runner.epoch = 284 # epoch < switch_epoch
  38. hook = PipelineSwitchHook(switch_epoch=285, switch_pipeline=stage2)
  39. hook.before_train_epoch(runner)
  40. self.assertFalse(hook._restart_dataloader)
  41. self.assertTrue(runner.train_dataloader._DataLoader__initialized)
  42. runner.epoch = 285 # epoch == switch_epoch
  43. hook.before_train_epoch(runner)
  44. self.assertFalse(hook._restart_dataloader)
  45. self.assertTrue(runner.train_dataloader._DataLoader__initialized)
  46. self.assertTrue(
  47. len(runner.train_dataloader.dataset.pipeline.transforms) == 1)
  48. runner.epoch = 286 # epoch > switch_epoch
  49. hook.before_train_epoch(runner)
  50. self.assertTrue(runner.train_dataloader._DataLoader__initialized)
  51. self.assertTrue(
  52. len(runner.train_dataloader.dataset.pipeline.transforms) == 1)
  53. def test_initialize_after_switching(self):
  54. # This simulates the resumption after the switching.
  55. runner = Mock()
  56. runner.model = Mock()
  57. runner.model.module = Mock()
  58. runner.train_dataloader = Mock()
  59. runner.train_dataloader.persistent_workers = True
  60. runner.train_dataloader._DataLoader__initialized = True
  61. stage2 = [dict(type='RandomResize', scale=(1280, 1280))]
  62. runner.epoch = 286 # epoch > switch_epoch
  63. hook = PipelineSwitchHook(switch_epoch=285, switch_pipeline=stage2)
  64. hook.before_train_epoch(runner)
  65. self.assertTrue(hook._restart_dataloader)
  66. self.assertFalse(runner.train_dataloader._DataLoader__initialized)
  67. self.assertTrue(
  68. len(runner.train_dataloader.dataset.pipeline.transforms) == 1)