test_track_img_sampler.py 3.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from collections.abc import Iterable
  3. from copy import deepcopy
  4. from unittest import TestCase
  5. from mmengine.dataset import ClassBalancedDataset, ConcatDataset
  6. from mmdet.datasets import MOTChallengeDataset, TrackImgSampler
  7. class TestTrackImgSampler(TestCase):
  8. def test_iter_base_video_dataset(self):
  9. # train mode
  10. dataset = MOTChallengeDataset(
  11. data_prefix=dict(img_path='imgs'),
  12. ann_file='tests/data/mot_sample.json',
  13. metainfo=dict(classes=('pedestrian')),
  14. filter_cfg=dict(filter_empty_gt=True, min_size=32),
  15. test_mode=False,
  16. pipeline=[])
  17. video_sampler = TrackImgSampler(dataset)
  18. assert len(video_sampler) == 5
  19. iterator = iter(video_sampler)
  20. assert isinstance(iterator, Iterable)
  21. for index in iterator:
  22. assert isinstance(index, tuple)
  23. video_index, frame_index = index
  24. assert video_index < 2
  25. if video_index == 0:
  26. assert frame_index >= 0 and frame_index < 3
  27. else:
  28. assert frame_index >= 0 and frame_index < 2
  29. # test mode
  30. dataset = MOTChallengeDataset(
  31. data_prefix=dict(img_path='imgs'),
  32. ann_file='tests/data/mot_sample.json',
  33. metainfo=dict(classes=('pedestrian')),
  34. filter_cfg=dict(filter_empty_gt=True, min_size=32),
  35. test_mode=True,
  36. pipeline=[])
  37. video_sampler = TrackImgSampler(dataset)
  38. assert len(video_sampler) == 5
  39. assert len(video_sampler.indices) == 1
  40. def test_iter_concat_dataset(self):
  41. single_dataset = MOTChallengeDataset(
  42. data_prefix=dict(img_path='imgs'),
  43. ann_file='tests/data/mot_sample.json',
  44. metainfo=dict(classes=('pedestrian')),
  45. filter_cfg=dict(filter_empty_gt=True, min_size=32),
  46. test_mode=False,
  47. pipeline=[])
  48. dataset = ConcatDataset([single_dataset, deepcopy(single_dataset)])
  49. video_sampler = TrackImgSampler(dataset)
  50. assert len(video_sampler) == 10
  51. iterator = iter(video_sampler)
  52. assert isinstance(iterator, Iterable)
  53. for index in iterator:
  54. assert isinstance(index, tuple)
  55. video_index, frame_index = index
  56. assert video_index < 4
  57. if video_index == 0:
  58. assert frame_index >= 0 and frame_index < 3
  59. elif video_index == 3:
  60. assert frame_index >= 0 and frame_index < 2
  61. def test_iter_class_balanced_dataset(self):
  62. single_dataset = MOTChallengeDataset(
  63. data_prefix=dict(img_path='imgs'),
  64. ann_file='tests/data/mot_sample.json',
  65. metainfo=dict(classes=('pedestrian', 'person_on_vehicle')),
  66. filter_cfg=dict(filter_empty_gt=True, min_size=32),
  67. visibility_thr=0.1,
  68. test_mode=False,
  69. pipeline=[])
  70. dataset = ClassBalancedDataset(single_dataset, oversample_thr=0.6)
  71. video_sampler = TrackImgSampler(dataset)
  72. assert len(video_sampler) == 8
  73. iterator = iter(video_sampler)
  74. assert isinstance(iterator, Iterable)
  75. for index in iterator:
  76. assert isinstance(index, tuple)
  77. video_index, frame_index = index
  78. assert video_index < 3
  79. if video_index == 0 or video_index == 2:
  80. assert frame_index >= 0 and frame_index < 3
  81. else:
  82. assert frame_index >= 0 and frame_index < 2