test_mot_challenge_dataset.py 1.5 KB

12345678910111213141516171819202122232425262728293031323334353637
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import unittest
  3. from mmdet.datasets import MOTChallengeDataset
  4. class TestMOTChallengeDataset(unittest.TestCase):
  5. def test_mot_challenge_dataset(self):
  6. # test CocoDataset
  7. metainfo = dict(classes=('pedestrian'), task_name='new_task')
  8. dataset = MOTChallengeDataset(
  9. data_prefix=dict(img_path='imgs'),
  10. ann_file='tests/data/mot_sample.json',
  11. metainfo=metainfo,
  12. filter_cfg=dict(filter_empty_gt=True, min_size=32),
  13. pipeline=[],
  14. serialize_data=False,
  15. lazy_init=False)
  16. self.assertEqual(dataset.metainfo['classes'], ('pedestrian'))
  17. self.assertEqual(dataset.metainfo['task_name'], 'new_task')
  18. self.assertListEqual(dataset.get_cat_ids((0, 1)), [0, 0])
  19. self.assertListEqual(dataset.get_cat_ids(0), [0, 0, 0, 0, 0, 0])
  20. self.assertEqual(len(dataset), 2)
  21. self.assertEqual(dataset.num_all_imgs, 5)
  22. self.assertEqual(len(dataset[0]['images'][2]['instances']), 2)
  23. def test_mot_challenge_dataset_with_visibility(self):
  24. dataset = MOTChallengeDataset(
  25. data_prefix=dict(img_path='imgs'),
  26. ann_file='tests/data/mot_sample.json',
  27. metainfo=dict(classes=('pedestrian')),
  28. filter_cfg=dict(filter_empty_gt=True, min_size=32),
  29. visibility_thr=0.5,
  30. pipeline=[])
  31. self.assertEqual(dataset.num_all_imgs, 5)
  32. self.assertEqual(len(dataset[0]['images'][2]['instances']), 1)