test_youtube_vis_metric.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import os
  3. import tempfile
  4. from unittest import TestCase
  5. import numpy as np
  6. import pycocotools.mask as mask_util
  7. import torch
  8. from mmengine.registry import init_default_scope
  9. from mmengine.structures import BaseDataElement, InstanceData
  10. from mmdet.registry import METRICS
  11. from mmdet.structures import DetDataSample, TrackDataSample
  class TestYouTubeVISMetric(TestCase):
      """Tests for the ``YouTubeVISMetric`` built from the mmdet registry."""

      @classmethod
      def setUpClass(cls):
          # Make 'mmdet' the default registry scope before any METRICS.build.
          init_default_scope('mmdet')

      def setUp(self):
          # Per-test scratch directory for the files the metric writes.
          self.tmp_dir = tempfile.TemporaryDirectory()

      def tearDown(self):
          # Remove the scratch directory together with its contents.
          self.tmp_dir.cleanup()

      def _create_dummy_results(self, track_id):
          """Build the fields of a single-instance tracking prediction.

          Args:
              track_id (int): Track id assigned to the predicted instance.

          Returns:
              dict: Tensors to be unpacked into ``InstanceData``: one box
              ``[100, 100, 150, 150]`` with score 1.0 and label 0, and a
              ``(1, 720, 1280)`` uint8 mask that is 1 on rows and columns
              ``100:150``.
          """
          bboxes = np.array([[100, 100, 150, 150]])
          scores = np.array([1.0])
          labels = np.array([0])
          instance_id = np.array([track_id])
          dummy_mask = np.zeros((1, 720, 1280), dtype=np.uint8)
          dummy_mask[:, 100:150, 100:150] = 1
          return dict(
              bboxes=torch.from_numpy(bboxes),
              scores=torch.from_numpy(scores),
              labels=torch.from_numpy(labels),
              instances_id=torch.from_numpy(instance_id),
              masks=torch.from_numpy(dummy_mask))

      def _create_gt_instances(self, instance_id):
          """Build the ground-truth ``instances`` meta info of one frame.

          The single GT instance has label 0, box ``[100, 100, 150, 150]``
          and an RLE mask covering rows and columns ``100:150`` of a
          720x1280 image, i.e. the same region as the dummy prediction.

          Args:
              instance_id (int): Id of the GT instance.

          Returns:
              list[dict]: One annotation dict whose mask is an RLE with
              ``counts`` stored as ``str``.
          """
          gt_mask = np.zeros((720, 1280), order='F', dtype=np.uint8)
          gt_mask[100:150, 100:150] = 1
          rle_mask = mask_util.encode(gt_mask)
          # ``counts`` comes back as bytes; store it as str (presumably so
          # the annotation is JSON-serializable).
          rle_mask['counts'] = rle_mask['counts'].decode('utf-8')
          return [{
              'bbox_label': 0,
              'bbox': [100, 100, 150, 150],
              'ignore_flag': 0,
              'instance_id': instance_id,
              'mask': rle_mask,
          }]

      def _create_img_data_sample(self, pred, img_id, video_id, video_length,
                                  instances):
          """Wrap one frame's prediction and GT meta info in a DetDataSample.

          Args:
              pred (dict): Fields produced by ``_create_dummy_results``.
              img_id (int): Id of the frame.
              video_id (int): Id of the video the frame belongs to.
              video_length (int): Number of frames in that video, stored as
                  ``ori_video_length``.
              instances (list[dict]): GT annotations of the frame.

          Returns:
              DetDataSample: Sample with ``pred_track_instances`` and the
              frame meta info set.
          """
          img_data_sample = DetDataSample()
          img_data_sample.pred_track_instances = InstanceData(**pred)
          img_data_sample.set_metainfo(
              dict(
                  img_id=img_id,
                  video_id=video_id,
                  ori_video_length=video_length,
                  ori_shape=(720, 1280),
                  instances=instances))
          return img_data_sample

      def _process_video(self, vis_metric, img_data_samples):
          """Feed all frames of one video to ``vis_metric.process``.

          The frames are packed into one ``TrackDataSample``, which is
          converted to a dict before the call, in the Evaluator-like way this
          test exercises the metric.

          Args:
              vis_metric: The metric under test.
              img_data_samples (list[DetDataSample]): Frames of one video.
          """
          track_data_sample = TrackDataSample()
          track_data_sample.video_data_samples = img_data_samples
          predictions = []
          # Data elements are handed to the metric as plain dicts.
          if isinstance(track_data_sample, BaseDataElement):
              predictions.append(track_data_sample.to_dict())
          data_batch = dict(inputs=None, data_samples=None)
          vis_metric.process(data_batch, predictions)

      def test_format_only(self):
          """With ``format_only=True`` the metric writes its result files."""
          outfile_prefix = os.path.join(self.tmp_dir.name, 'result')
          vis_metric = METRICS.build(
              dict(
                  type='YouTubeVISMetric',
                  format_only=True,
                  outfile_prefix=outfile_prefix,
              ))
          vis_metric.dataset_meta = dict(classes=['car', 'train'])

          frame = self._create_img_data_sample(
              self._create_dummy_results(track_id=0),
              img_id=0,
              video_id=1,
              video_length=1,
              instances=self._create_gt_instances(instance_id=1))
          self._process_video(vis_metric, [frame])
          vis_metric.evaluate(size=1)

          self.assertTrue(os.path.exists(f'{outfile_prefix}.json'))
          self.assertTrue(
              os.path.exists(f'{outfile_prefix}.submission_file.zip'))

      def test_evaluate(self):
          """Test using the metric in the same way as Evaluator."""
          vis_metric = METRICS.build(
              dict(
                  type='YouTubeVISMetric',
                  outfile_prefix=os.path.join(self.tmp_dir.name, 'test'),
              ))
          vis_metric.dataset_meta = dict(classes=['car', 'train'])

          # Video 1: two frames, both predicting track 1 over GT instance 1.
          video1_gt = self._create_gt_instances(instance_id=1)
          frame_1 = self._create_img_data_sample(
              self._create_dummy_results(track_id=1),
              img_id=1,
              video_id=1,
              video_length=2,
              instances=video1_gt)
          frame_2 = self._create_img_data_sample(
              self._create_dummy_results(track_id=1),
              img_id=2,
              video_id=1,
              video_length=2,
              instances=video1_gt)
          self._process_video(vis_metric, [frame_1, frame_2])

          # Video 2: a single frame predicting track 2 over GT instance 2.
          frame_3 = self._create_img_data_sample(
              self._create_dummy_results(track_id=2),
              img_id=3,
              video_id=2,
              video_length=1,
              instances=self._create_gt_instances(instance_id=2))
          self._process_video(vis_metric, [frame_3])

          eval_results = vis_metric.evaluate(size=3)
          # Predicted masks coincide with the GT masks, so the populated
          # metrics are 1.0; the -1.0 entries presumably mark area ranges
          # that contain no GT instance (COCO-style) -- confirm in the metric.
          target = {
              'youtube_vis/segm_mAP': 1.0,
              'youtube_vis/segm_mAP_50': 1.0,
              'youtube_vis/segm_mAP_75': 1.0,
              'youtube_vis/segm_mAP_s': 1.0,
              'youtube_vis/segm_mAP_m': -1.0,
              'youtube_vis/segm_mAP_l': -1.0,
          }
          self.assertDictEqual(eval_results, target)