import unittest

import numpy as np

from mmdet.datasets.transforms import BaseFrameSample, UniformRefFrameSample


class TestFrameSample(unittest.TestCase):
    """Tests for ``BaseFrameSample`` and ``UniformRefFrameSample``.

    Every test runs a sampler on the synthetic 10-frame video built in
    ``setUp``. Each per-frame field of frame ``i`` is filled with (or
    offset by) ``i``, so a sampled frame can be identified purely from its
    contents.
    """

    def setUp(self) -> None:
        """Build the synthetic video annotation shared by every test.

        ``self.video_infos`` describes a video with ``video_id=0``,
        ``video_length=10`` and ``key_frame_id=4``. Its ``'images'`` list
        holds one dict per frame; for frame ``i`` the arrays are constant
        arrays of value ``i``, ``ori_shape`` is ``(H + i, W + i)`` and both
        ``frame_id`` and ``img_id`` equal ``i``.

        ``self.info_keys`` lists the keys each test requires in the sampler
        output.
        """
        self.H, self.W = 5, 8

        self.video_infos = dict(
            video_id=0, video_length=10, key_frame_id=4, images=[])
        self.info_keys = [
            'video_id', 'video_length', 'img', 'gt_bboxes', 'gt_bboxes_labels',
            'gt_instances_id', 'img_id', 'frame_id'
        ]
        for i in range(self.video_infos['video_length']):
            self.video_infos['images'].append(
                dict(
                    img=np.zeros((self.H, self.W, 3)) + i,
                    gt_bboxes=np.zeros((2, 4)) + i,
                    gt_bboxes_labels=np.zeros((2, )) + i,
                    gt_instances_id=np.zeros((2, ), dtype=np.int32) + i,
                    ori_shape=(self.H + i, self.W + i),
                    frame_id=i,
                    img_id=i))

    def _assert_frame_content(self, results: dict, index: int,
                              frame_id: int) -> None:
        """Assert that entry ``index`` of ``results`` is frame ``frame_id``.

        Relies on the ``setUp`` convention that every field of frame ``i``
        encodes ``i``.

        Args:
            results (dict): Output of a frame sampler; each checked key maps
                to a per-frame sequence.
            index (int): Position inside the sampled sequence to inspect.
            frame_id (int): Index of the source frame expected there.
        """
        np.testing.assert_array_equal(
            results['img'][index], np.full((self.H, self.W, 3), frame_id))
        np.testing.assert_array_equal(results['gt_bboxes'][index],
                                      np.full((2, 4), frame_id))
        np.testing.assert_array_equal(results['gt_bboxes_labels'][index],
                                      np.full((2, ), frame_id))
        np.testing.assert_array_equal(results['gt_instances_id'][index],
                                      np.full((2, ), frame_id))
        self.assertEqual(results['ori_shape'][index],
                         (self.H + frame_id, self.W + frame_id))
        self.assertEqual(results['img_id'][index], frame_id)

    def test_base_frame_sample(self) -> None:
        """``BaseFrameSample`` is expected to return only the key frame."""
        key_frame_id = self.video_infos['key_frame_id']

        sampler = BaseFrameSample()
        results = sampler(self.video_infos)

        self.assertIsInstance(results, dict)
        for key in self.info_keys:
            self.assertIn(key, results)
            self.assertEqual(len(results[key]), 1, msg=key)
        self.assertEqual(results['frame_id'], [key_frame_id])

        self._assert_frame_content(results, 0, key_frame_id)

    def test_uniform_ref_frame_sample(self) -> None:
        """Check the frames and key-frame flags ``UniformRefFrameSample``
        returns for two sampling configurations."""
        key_frame_id = self.video_infos['key_frame_id']

        # Two reference frames drawn from the window [-1, 1] with the key
        # frame filtered out: the expected result is the key frame
        # surrounded by its two direct neighbours.
        sampler = UniformRefFrameSample(
            num_ref_imgs=2, frame_range=[-1, 1], filter_key_img=True)
        results = sampler(self.video_infos)

        self.assertIsInstance(results, dict)
        for key in self.info_keys:
            self.assertIn(key, results)
            self.assertEqual(len(results[key]), 3, msg=key)
        self.assertEqual(results['frame_id'],
                         [key_frame_id - 1, key_frame_id, key_frame_id + 1])

        # The key frame is expected in the middle of the sampled sequence.
        self._assert_frame_content(results, 1, key_frame_id)

        # test the filter_key_img and the correctness of returned frame index
        # With the key frame kept as a candidate and the window [0, 1], the
        # key frame must show up twice (once as key, once as reference) and
        # its successor once. ``img_id`` equals the frame index in setUp.
        sampler = UniformRefFrameSample(
            num_ref_imgs=2, frame_range=[0, 1], filter_key_img=False)
        results = sampler(self.video_infos)

        self.assertEqual(results['img_id'].count(key_frame_id), 2)
        self.assertEqual(results['img_id'].count(key_frame_id + 1), 1)
        self.assertEqual(results['key_frame_flags'], [True, False, False])

    def test_repr(self) -> None:
        """``repr`` of both samplers must list their configuration."""
        transform = BaseFrameSample()
        self.assertEqual(
            repr(transform),
            "BaseFrameSample(collect_video_keys=['video_id', 'video_length'])")

        # An integer ``frame_range`` of 10 is expected to be reported as the
        # symmetric window ``[-10, 10]``.
        transform = UniformRefFrameSample(
            num_ref_imgs=2, frame_range=10, filter_key_img=True)
        self.assertEqual(
            repr(transform),
            ('UniformRefFrameSample(num_ref_imgs=2, '
             'frame_range=[-10, 10], filter_key_img=True, '
             "collect_video_keys=['video_id', 'video_length'])"))