test_frame_sampling.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. import unittest
  2. import numpy as np
  3. from mmdet.datasets.transforms import BaseFrameSample, UniformRefFrameSample
  4. class TestFrameSample(unittest.TestCase):
  5. def setUp(self):
  6. """Setup the model and optimizer which are used in every test method.
  7. TestCase calls functions in this order: setUp() -> testMethod()
  8. -> tearDown() -> cleanUp()
  9. """
  10. self.H, self.W = 5, 8
  11. self.img = np.zeros((self.H, self.W, 3))
  12. self.gt_bboxes = np.zeros((2, 4))
  13. self.gt_bboxes_labels = [
  14. np.zeros((2, )),
  15. np.zeros((2, )) + 1,
  16. np.zeros((2, )) - 1
  17. ]
  18. self.gt_instances_id = [
  19. np.ones((2, ), dtype=np.int32),
  20. np.ones((2, ), dtype=np.int32) - 1,
  21. np.ones((2, ), dtype=np.int32) + 1
  22. ]
  23. self.frame_id = [0, 1, 2]
  24. self.scale_factor = [1.0, 1.5, 2.0]
  25. self.flip = [False] * 3
  26. self.ori_shape = [(self.H, self.W)] * 3
  27. self.img_id = [0, 1, 2]
  28. self.video_infos = dict(video_id=0, video_length=10, key_frame_id=4)
  29. self.video_infos['images'] = []
  30. self.info_keys = [
  31. 'video_id', 'video_length', 'img', 'gt_bboxes', 'gt_bboxes_labels',
  32. 'gt_instances_id', 'img_id', 'frame_id'
  33. ]
  34. for i in range(10):
  35. frame_info = dict(
  36. img=np.zeros((self.H, self.W, 3)) + i,
  37. gt_bboxes=np.zeros((2, 4)) + i,
  38. gt_bboxes_labels=np.zeros((2, )) + i,
  39. gt_instances_id=np.zeros((2, ), dtype=np.int32) + i,
  40. ori_shape=(self.H + i, self.W + i),
  41. frame_id=i,
  42. img_id=i)
  43. self.video_infos['images'].append(frame_info)
  44. def test_base_frame_sample(self):
  45. sampler = BaseFrameSample()
  46. results = sampler(self.video_infos)
  47. assert isinstance(results, dict)
  48. for key in self.info_keys:
  49. assert key in results
  50. assert len(results[key]) == 1
  51. if key == 'frame_id':
  52. assert results[key] == [4]
  53. key_frame_id = self.video_infos['key_frame_id']
  54. assert (results['img'][0] == np.zeros(
  55. (self.H, self.W, 3)) + key_frame_id).all()
  56. assert (results['gt_bboxes'][0] == np.zeros(
  57. (2, 4)) + key_frame_id).all()
  58. assert (results['gt_bboxes_labels'][0] == np.zeros(
  59. (2, )) + key_frame_id).all()
  60. assert (results['gt_instances_id'][0] == np.zeros(
  61. (2, )) + key_frame_id).all()
  62. assert results['ori_shape'][0] == (self.H + key_frame_id,
  63. self.W + key_frame_id)
  64. assert results['img_id'][0] == key_frame_id
  65. def test_uniform_ref_frame_sample(self):
  66. sampler = UniformRefFrameSample(
  67. num_ref_imgs=2, frame_range=[-1, 1], filter_key_img=True)
  68. results = sampler(self.video_infos)
  69. assert isinstance(results, dict)
  70. for key in self.info_keys:
  71. assert key in results
  72. assert len(results[key]) == 3
  73. if key == 'frame_id':
  74. assert results[key] == [3, 4, 5]
  75. key_frame_id = self.video_infos['key_frame_id']
  76. assert (results['img'][1] == np.zeros(
  77. (self.H, self.W, 3)) + key_frame_id).all()
  78. assert (results['gt_bboxes'][1] == np.zeros(
  79. (2, 4)) + key_frame_id).all()
  80. assert (results['gt_bboxes_labels'][1] == np.zeros(
  81. (2, )) + key_frame_id).all()
  82. assert (results['gt_instances_id'][1] == np.zeros(
  83. (2, )) + key_frame_id).all()
  84. assert results['ori_shape'][1] == (self.H + key_frame_id,
  85. self.W + key_frame_id)
  86. assert results['img_id'][1] == key_frame_id
  87. # test the filter_key_img and the correctness of returned frame index
  88. sampler = UniformRefFrameSample(
  89. num_ref_imgs=2, frame_range=[0, 1], filter_key_img=False)
  90. results = sampler(self.video_infos)
  91. assert 4 in results['img_id'] and results['img_id'].count(4) == 2
  92. assert 5 in results['img_id'] and results['img_id'].count(5) == 1
  93. assert results['key_frame_flags'] == [True, False, False]
  94. def test_repr(self):
  95. transform = BaseFrameSample()
  96. self.assertEqual(
  97. repr(transform),
  98. "BaseFrameSample(collect_video_keys=['video_id', 'video_length'])")
  99. transform = UniformRefFrameSample(
  100. num_ref_imgs=2, frame_range=10, filter_key_img=True)
  101. self.assertEqual(
  102. repr(transform),
  103. ('UniformRefFrameSample(num_ref_imgs=2, '
  104. 'frame_range=[-10, 10], filter_key_img=True, '
  105. "collect_video_keys=['video_id', 'video_length'])"))