test_local_visualizer.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199
  1. import os
  2. from unittest import TestCase
  3. import cv2
  4. import numpy as np
  5. import torch
  6. from mmengine.structures import InstanceData, PixelData
  7. from mmdet.evaluation import INSTANCE_OFFSET
  8. from mmdet.structures import DetDataSample
  9. from mmdet.visualization import DetLocalVisualizer, TrackLocalVisualizer
  10. def _rand_bboxes(num_boxes, h, w):
  11. cx, cy, bw, bh = torch.rand(num_boxes, 4).T
  12. tl_x = ((cx * w) - (w * bw / 2)).clamp(0, w)
  13. tl_y = ((cy * h) - (h * bh / 2)).clamp(0, h)
  14. br_x = ((cx * w) + (w * bw / 2)).clamp(0, w)
  15. br_y = ((cy * h) + (h * bh / 2)).clamp(0, h)
  16. bboxes = torch.stack([tl_x, tl_y, br_x, br_y], dim=0).T
  17. return bboxes
  18. def _create_panoptic_data(num_boxes, h, w):
  19. sem_seg = np.zeros((h, w), dtype=np.int64) + 2
  20. bboxes = _rand_bboxes(num_boxes, h, w).int()
  21. labels = torch.randint(2, (num_boxes, ))
  22. for i in range(num_boxes):
  23. x, y, w, h = bboxes[i]
  24. sem_seg[y:y + h, x:x + w] = (i + 1) * INSTANCE_OFFSET + labels[i]
  25. return sem_seg[None]
  26. class TestDetLocalVisualizer(TestCase):
  27. def test_add_datasample(self):
  28. h = 12
  29. w = 10
  30. num_class = 3
  31. num_bboxes = 5
  32. out_file = 'out_file.jpg'
  33. image = np.random.randint(0, 256, size=(h, w, 3)).astype('uint8')
  34. # test gt_instances
  35. gt_instances = InstanceData()
  36. gt_instances.bboxes = _rand_bboxes(num_bboxes, h, w)
  37. gt_instances.labels = torch.randint(0, num_class, (num_bboxes, ))
  38. det_data_sample = DetDataSample()
  39. det_data_sample.gt_instances = gt_instances
  40. det_local_visualizer = DetLocalVisualizer()
  41. det_local_visualizer.add_datasample(
  42. 'image', image, det_data_sample, draw_pred=False)
  43. # test out_file
  44. det_local_visualizer.add_datasample(
  45. 'image',
  46. image,
  47. det_data_sample,
  48. draw_pred=False,
  49. out_file=out_file)
  50. assert os.path.exists(out_file)
  51. drawn_img = cv2.imread(out_file)
  52. assert drawn_img.shape == (h, w, 3)
  53. os.remove(out_file)
  54. # test gt_instances and pred_instances
  55. pred_instances = InstanceData()
  56. pred_instances.bboxes = _rand_bboxes(num_bboxes, h, w)
  57. pred_instances.labels = torch.randint(0, num_class, (num_bboxes, ))
  58. pred_instances.scores = torch.rand((num_bboxes, ))
  59. det_data_sample.pred_instances = pred_instances
  60. det_local_visualizer.add_datasample(
  61. 'image', image, det_data_sample, out_file=out_file)
  62. self._assert_image_and_shape(out_file, (h, w * 2, 3))
  63. det_local_visualizer.add_datasample(
  64. 'image', image, det_data_sample, draw_gt=False, out_file=out_file)
  65. self._assert_image_and_shape(out_file, (h, w, 3))
  66. det_local_visualizer.add_datasample(
  67. 'image',
  68. image,
  69. det_data_sample,
  70. draw_pred=False,
  71. out_file=out_file)
  72. self._assert_image_and_shape(out_file, (h, w, 3))
  73. # test gt_panoptic_seg and pred_panoptic_seg
  74. det_local_visualizer.dataset_meta = dict(classes=('1', '2'))
  75. gt_sem_seg = _create_panoptic_data(num_bboxes, h, w)
  76. panoptic_seg = PixelData(sem_seg=gt_sem_seg)
  77. det_data_sample = DetDataSample()
  78. det_data_sample.gt_panoptic_seg = panoptic_seg
  79. pred_sem_seg = _create_panoptic_data(num_bboxes, h, w)
  80. panoptic_seg = PixelData(sem_seg=pred_sem_seg)
  81. det_data_sample.pred_panoptic_seg = panoptic_seg
  82. det_local_visualizer.add_datasample(
  83. 'image', image, det_data_sample, out_file=out_file)
  84. self._assert_image_and_shape(out_file, (h, w * 2, 3))
  85. # class information must be provided
  86. det_local_visualizer.dataset_meta = {}
  87. with self.assertRaises(AssertionError):
  88. det_local_visualizer.add_datasample(
  89. 'image', image, det_data_sample, out_file=out_file)
  90. def _assert_image_and_shape(self, out_file, out_shape):
  91. assert os.path.exists(out_file)
  92. drawn_img = cv2.imread(out_file)
  93. assert drawn_img.shape == out_shape
  94. os.remove(out_file)
  95. class TestTrackLocalVisualizer(TestCase):
  96. @staticmethod
  97. def _get_gt_instances():
  98. bboxes = np.array([[912, 484, 1009, 593], [1338, 418, 1505, 797]])
  99. masks = np.zeros((2, 1080, 1920), dtype=np.bool_)
  100. for i, bbox in enumerate(bboxes):
  101. masks[i, bbox[1]:bbox[3], bbox[0]:bbox[2]] = True
  102. instances_data = dict(
  103. bboxes=torch.tensor(bboxes),
  104. masks=masks,
  105. instances_id=torch.tensor([1, 2]),
  106. labels=torch.tensor([0, 1]))
  107. instances = InstanceData(**instances_data)
  108. return instances
  109. @staticmethod
  110. def _get_pred_instances():
  111. instances_data = dict(
  112. bboxes=torch.tensor([[900, 500, 1000, 600], [1300, 400, 1500,
  113. 800]]),
  114. instances_id=torch.tensor([1, 2]),
  115. labels=torch.tensor([0, 1]),
  116. scores=torch.tensor([0.955, 0.876]))
  117. instances = InstanceData(**instances_data)
  118. return instances
  119. @staticmethod
  120. def _assert_image_and_shape(out_file, out_shape):
  121. assert os.path.exists(out_file)
  122. drawn_img = cv2.imread(out_file)
  123. assert drawn_img.shape == out_shape
  124. os.remove(out_file)
  125. def test_add_datasample(self):
  126. out_file = 'out_file.jpg'
  127. h, w = 1080, 1920
  128. image = np.random.randint(0, 256, size=(h, w, 3)).astype('uint8')
  129. gt_instances = self._get_gt_instances()
  130. pred_instances = self._get_pred_instances()
  131. image_data_sample = DetDataSample()
  132. image_data_sample.gt_instances = gt_instances
  133. image_data_sample.pred_track_instances = pred_instances
  134. track_local_visualizer = TrackLocalVisualizer(alpha=0.2)
  135. track_local_visualizer.dataset_meta = dict(
  136. classes=['pedestrian', 'vehicle'])
  137. # test gt_instances
  138. track_local_visualizer.add_datasample('image', image,
  139. image_data_sample, None)
  140. # test out_file
  141. track_local_visualizer.add_datasample(
  142. 'image', image, image_data_sample, None, out_file=out_file)
  143. self._assert_image_and_shape(out_file, (h, w, 3))
  144. # test gt_instances and pred_instances
  145. track_local_visualizer.add_datasample(
  146. 'image', image, image_data_sample, out_file=out_file)
  147. self._assert_image_and_shape(out_file, (h, 2 * w, 3))
  148. track_local_visualizer.add_datasample(
  149. 'image',
  150. image,
  151. image_data_sample,
  152. draw_gt=False,
  153. out_file=out_file)
  154. self._assert_image_and_shape(out_file, (h, w, 3))
  155. track_local_visualizer.add_datasample(
  156. 'image',
  157. image,
  158. image_data_sample,
  159. draw_pred=False,
  160. out_file=out_file)
  161. self._assert_image_and_shape(out_file, (h, w, 3))