# Copyright (c) OpenMMLab. All rights reserved.

from unittest import TestCase
from unittest.mock import patch

import numpy as np
from mmengine.dataset import DefaultSampler
from torch.utils.data import Dataset

from mmdet.datasets.samplers import AspectRatioBatchSampler


class DummyDataset(Dataset):

    def __init__(self, length):
        self.length = length
        self.shapes = np.random.random((length, 2))

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        return self.shapes[idx]

    def get_data_info(self, idx):
        return dict(width=self.shapes[idx][0], height=self.shapes[idx][1])


class TestAspectRatioBatchSampler(TestCase):

    @patch('mmengine.dist.get_dist_info', return_value=(0, 1))
    def setUp(self, mock):
        self.length = 100
        self.dataset = DummyDataset(self.length)
        self.sampler = DefaultSampler(self.dataset, shuffle=False)

    def test_invalid_inputs(self):
        with self.assertRaisesRegex(
                ValueError, 'batch_size should be a positive integer value'):
            AspectRatioBatchSampler(self.sampler, batch_size=-1)

        with self.assertRaisesRegex(
                TypeError, 'sampler should be an instance of ``Sampler``'):
            AspectRatioBatchSampler(None, batch_size=1)

    def test_divisible_batch(self):
        batch_size = 5
        batch_sampler = AspectRatioBatchSampler(
            self.sampler, batch_size=batch_size, drop_last=True)
        self.assertEqual(len(batch_sampler), self.length // batch_size)
        for batch_idxs in batch_sampler:
            self.assertEqual(len(batch_idxs), batch_size)
            batch = [self.dataset[idx] for idx in batch_idxs]
            flag = batch[0][0] < batch[0][1]
            for i in range(1, batch_size):
                self.assertEqual(batch[i][0] < batch[i][1], flag)

    def test_indivisible_batch(self):
        batch_size = 7
        batch_sampler = AspectRatioBatchSampler(
            self.sampler, batch_size=batch_size, drop_last=False)
        all_batch_idxs = list(batch_sampler)
        self.assertEqual(
            len(batch_sampler), (self.length + batch_size - 1) // batch_size)
        self.assertEqual(
            len(all_batch_idxs), (self.length + batch_size - 1) // batch_size)

        batch_sampler = AspectRatioBatchSampler(
            self.sampler, batch_size=batch_size, drop_last=True)
        all_batch_idxs = list(batch_sampler)
        self.assertEqual(len(batch_sampler), self.length // batch_size)
        self.assertEqual(len(all_batch_idxs), self.length // batch_size)

        # the last batch may not have the same aspect ratio
        for batch_idxs in all_batch_idxs[:-1]:
            self.assertEqual(len(batch_idxs), batch_size)
            batch = [self.dataset[idx] for idx in batch_idxs]
            flag = batch[0][0] < batch[0][1]
            for i in range(1, batch_size):
                self.assertEqual(batch[i][0] < batch[i][1], flag)