# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
from unittest.mock import Mock, patch

from mmdet.engine.hooks import YOLOXModeSwitchHook


class TestYOLOXModeSwitchHook(TestCase):

    @patch('mmdet.engine.hooks.yolox_mode_switch_hook.is_model_wrapper')
    def test_is_model_wrapper_and_persistent_workers_on(
            self, mock_is_model_wrapper):
        mock_is_model_wrapper.return_value = True
        runner = Mock()
        runner.model = Mock()
        runner.model.module = Mock()
        runner.model.module.detector.bbox_head.use_l1 = False
        runner.train_dataloader = Mock()
        runner.train_dataloader.persistent_workers = True
        runner.train_dataloader._DataLoader__initialized = True
        runner.epoch = 284
        runner.max_epochs = 300

        hook = YOLOXModeSwitchHook(num_last_epochs=15)
        hook.before_train_epoch(runner)
        self.assertTrue(hook._restart_dataloader)
        self.assertTrue(runner.model.module.detector.bbox_head.use_l1)
        self.assertFalse(runner.train_dataloader._DataLoader__initialized)

        runner.epoch = 285
        hook.before_train_epoch(runner)
        self.assertTrue(runner.train_dataloader._DataLoader__initialized)

    def test_not_model_wrapper_and_persistent_workers_off(self):
        runner = Mock()
        runner.model = Mock()
        runner.model.detector.bbox_head.use_l1 = False
        runner.train_dataloader = Mock()
        runner.train_dataloader.persistent_workers = False
        runner.train_dataloader._DataLoader__initialized = True
        runner.epoch = 284
        runner.max_epochs = 300

        hook = YOLOXModeSwitchHook(num_last_epochs=15)
        hook.before_train_epoch(runner)
        self.assertFalse(hook._restart_dataloader)
        self.assertTrue(runner.model.detector.bbox_head.use_l1)
        self.assertTrue(runner.train_dataloader._DataLoader__initialized)

        runner.epoch = 285
        hook.before_train_epoch(runner)
        self.assertFalse(hook._restart_dataloader)
        self.assertTrue(runner.train_dataloader._DataLoader__initialized)

    @patch('mmdet.engine.hooks.yolox_mode_switch_hook.is_model_wrapper')
    def test_initialize_after_switching(self, mock_is_model_wrapper):
        # This simulates the resumption after the switching.
        mock_is_model_wrapper.return_value = True
        runner = Mock()
        runner.model = Mock()
        runner.model.module = Mock()
        runner.model.module.bbox_head.use_l1 = False
        runner.train_dataloader = Mock()
        runner.train_dataloader.persistent_workers = True
        runner.train_dataloader._DataLoader__initialized = True
        runner.epoch = 285
        runner.max_epochs = 300

        # epoch + 1 > max_epochs - num_last_epochs .
        hook = YOLOXModeSwitchHook(num_last_epochs=15)
        hook.before_train_epoch(runner)
        self.assertTrue(hook._restart_dataloader)
        self.assertTrue(runner.model.module.detector.bbox_head.use_l1)
        self.assertFalse(runner.train_dataloader._DataLoader__initialized)