# Copyright (c) OpenMMLab. All rights reserved.
import unittest

from mmdet.datasets import VOCDataset


class TestVOCDataset(unittest.TestCase):

    def test_voc2007_init(self):
        dataset = VOCDataset(
            data_root='tests/data/VOCdevkit/',
            ann_file='VOC2007/ImageSets/Main/trainval.txt',
            data_prefix=dict(sub_data_root='VOC2007/'),
            filter_cfg=dict(
                filter_empty_gt=True, min_size=32, bbox_min_size=32),
            pipeline=[])
        dataset.full_init()
        self.assertEqual(len(dataset), 1)

        data_list = dataset.load_data_list()
        self.assertEqual(len(data_list), 1)
        self.assertEqual(len(data_list[0]['instances']), 2)
        self.assertEqual(dataset.get_cat_ids(0), [11, 14])

    def test_voc2012_init(self):
        dataset = VOCDataset(
            data_root='tests/data/VOCdevkit/',
            ann_file='VOC2012/ImageSets/Main/trainval.txt',
            data_prefix=dict(sub_data_root='VOC2012/'),
            filter_cfg=dict(filter_empty_gt=True, min_size=32),
            pipeline=[])
        dataset.full_init()
        self.assertEqual(len(dataset), 1)

        data_list = dataset.load_data_list()
        self.assertEqual(len(data_list), 1)
        self.assertEqual(len(data_list[0]['instances']), 1)
        self.assertEqual(dataset.get_cat_ids(0), [18])