# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
from unittest.mock import Mock

from mmdet.engine.hooks import PipelineSwitchHook


class TestPipelineSwitchHook(TestCase):
    """Unit tests for ``PipelineSwitchHook.before_train_epoch``.

    Every test drives the hook through epochs relative to
    ``switch_epoch=285`` using a fully mocked runner.
    """

    @staticmethod
    def _build_runner(persistent_workers):
        """Build a mocked runner whose train dataloader reports initialized.

        Args:
            persistent_workers (bool): Value exposed as
                ``runner.train_dataloader.persistent_workers``.

        Returns:
            Mock: The mocked runner.
        """
        runner = Mock()
        runner.model = Mock()
        runner.model.module = Mock()
        runner.train_dataloader = Mock()
        runner.train_dataloader.persistent_workers = persistent_workers
        # Name-mangled form of the private ``DataLoader.__initialized``
        # attribute; the tests read it to observe whether the hook forced
        # the dataloader to be rebuilt.
        runner.train_dataloader._DataLoader__initialized = True
        return runner

    @staticmethod
    def _switch_pipeline():
        """Return a fresh single-transform pipeline config for stage 2."""
        return [dict(type='RandomResize', scale=(1280, 1280))]

    def _assert_pipeline_switched(self, runner):
        """Assert the dataset pipeline now holds the one stage-2 transform."""
        self.assertEqual(
            len(runner.train_dataloader.dataset.pipeline.transforms), 1)

    def test_persistent_workers_on(self):
        """With persistent workers, switching must reset the dataloader."""
        runner = self._build_runner(persistent_workers=True)
        hook = PipelineSwitchHook(
            switch_epoch=285, switch_pipeline=self._switch_pipeline())

        runner.epoch = 284  # epoch < switch_epoch
        hook.before_train_epoch(runner)
        self.assertFalse(hook._restart_dataloader)
        self.assertTrue(runner.train_dataloader._DataLoader__initialized)

        runner.epoch = 285  # epoch == switch_epoch
        hook.before_train_epoch(runner)
        self.assertTrue(hook._restart_dataloader)
        self.assertFalse(runner.train_dataloader._DataLoader__initialized)
        self._assert_pipeline_switched(runner)

        runner.epoch = 286  # epoch > switch_epoch
        hook.before_train_epoch(runner)
        self.assertTrue(runner.train_dataloader._DataLoader__initialized)
        self._assert_pipeline_switched(runner)

    def test_persistent_workers_off(self):
        """Without persistent workers, the dataloader is never reset."""
        runner = self._build_runner(persistent_workers=False)
        hook = PipelineSwitchHook(
            switch_epoch=285, switch_pipeline=self._switch_pipeline())

        runner.epoch = 284  # epoch < switch_epoch
        hook.before_train_epoch(runner)
        self.assertFalse(hook._restart_dataloader)
        self.assertTrue(runner.train_dataloader._DataLoader__initialized)

        runner.epoch = 285  # epoch == switch_epoch
        hook.before_train_epoch(runner)
        self.assertFalse(hook._restart_dataloader)
        self.assertTrue(runner.train_dataloader._DataLoader__initialized)
        self._assert_pipeline_switched(runner)

        runner.epoch = 286  # epoch > switch_epoch
        hook.before_train_epoch(runner)
        self.assertTrue(runner.train_dataloader._DataLoader__initialized)
        self._assert_pipeline_switched(runner)

    def test_initialize_after_switching(self):
        """Resuming past ``switch_epoch`` must still switch the pipeline."""
        # This simulates the resumption after the switching.
        runner = self._build_runner(persistent_workers=True)
        hook = PipelineSwitchHook(
            switch_epoch=285, switch_pipeline=self._switch_pipeline())

        runner.epoch = 286  # epoch > switch_epoch
        hook.before_train_epoch(runner)
        self.assertTrue(hook._restart_dataloader)
        self.assertFalse(runner.train_dataloader._DataLoader__initialized)
        self._assert_pipeline_switched(runner)