1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192 |
- # Copyright (c) OpenMMLab. All rights reserved.
- from collections.abc import Iterable
- from copy import deepcopy
- from unittest import TestCase
- from mmengine.dataset import ClassBalancedDataset, ConcatDataset
- from mmdet.datasets import MOTChallengeDataset, TrackImgSampler
class TestTrackImgSampler(TestCase):
    """Tests for ``TrackImgSampler`` on plain and wrapped video datasets.

    The assertions below expect ``tests/data/mot_sample.json`` to hold two
    videos with 3 and 2 frames respectively (5 frames in total).
    """

    @staticmethod
    def _build_dataset(**overrides):
        """Build a ``MOTChallengeDataset`` over the MOT sample annotations.

        Args:
            **overrides: Keyword arguments that replace or extend the default
                dataset config shared by all tests in this class.

        Returns:
            MOTChallengeDataset: The constructed dataset.
        """
        cfg = dict(
            data_prefix=dict(img_path='imgs'),
            ann_file='tests/data/mot_sample.json',
            # The trailing comma matters: ``('pedestrian')`` is just a str,
            # not a one-element tuple of class names.
            metainfo=dict(classes=('pedestrian', )),
            filter_cfg=dict(filter_empty_gt=True, min_size=32),
            test_mode=False,
            pipeline=[])
        cfg.update(overrides)
        return MOTChallengeDataset(**cfg)

    def test_iter_base_video_dataset(self):
        # train mode: one (video_index, frame_index) pair per frame.
        dataset = self._build_dataset()
        video_sampler = TrackImgSampler(dataset)
        assert len(video_sampler) == 5
        iterator = iter(video_sampler)
        assert isinstance(iterator, Iterable)
        for index in iterator:
            assert isinstance(index, tuple)
            video_index, frame_index = index
            assert video_index < 2
            num_frames = 3 if video_index == 0 else 2
            assert 0 <= frame_index < num_frames

        # test mode
        dataset = self._build_dataset(test_mode=True)
        video_sampler = TrackImgSampler(dataset)
        assert len(video_sampler) == 5
        assert len(video_sampler.indices) == 1

    def test_iter_concat_dataset(self):
        single_dataset = self._build_dataset()
        dataset = ConcatDataset([single_dataset, deepcopy(single_dataset)])
        video_sampler = TrackImgSampler(dataset)
        assert len(video_sampler) == 10
        iterator = iter(video_sampler)
        assert isinstance(iterator, Iterable)
        for index in iterator:
            assert isinstance(index, tuple)
            video_index, frame_index = index
            assert video_index < 4
            # Both halves are copies of the same two videos (3 and 2 frames),
            # so even video indices have 3 frames and odd ones have 2; this
            # bounds every video, not only the first and the last.
            num_frames = 3 if video_index % 2 == 0 else 2
            assert 0 <= frame_index < num_frames

    def test_iter_class_balanced_dataset(self):
        single_dataset = self._build_dataset(
            metainfo=dict(classes=('pedestrian', 'person_on_vehicle')),
            visibility_thr=0.1)
        dataset = ClassBalancedDataset(single_dataset, oversample_thr=0.6)
        video_sampler = TrackImgSampler(dataset)
        assert len(video_sampler) == 8
        iterator = iter(video_sampler)
        assert isinstance(iterator, Iterable)
        for index in iterator:
            assert isinstance(index, tuple)
            video_index, frame_index = index
            assert video_index < 3
            # NOTE(review): these bounds assume sampler entries 0 and 2 are
            # 3-frame videos and entry 1 is a 2-frame video; confirm against
            # ClassBalancedDataset's repeat layout.
            num_frames = 3 if video_index in (0, 2) else 2
            assert 0 <= frame_index < num_frames
|