test_visualization_hook.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. # Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import shutil
import tempfile
import time
from unittest import TestCase
from unittest.mock import Mock

import torch
from mmengine.structures import InstanceData

from mmdet.engine.hooks import DetVisualizationHook, TrackVisualizationHook
from mmdet.structures import DetDataSample, TrackDataSample
from mmdet.visualization import DetLocalVisualizer, TrackLocalVisualizer
  12. def _rand_bboxes(num_boxes, h, w):
  13. cx, cy, bw, bh = torch.rand(num_boxes, 4).T
  14. tl_x = ((cx * w) - (w * bw / 2)).clamp(0, w)
  15. tl_y = ((cy * h) - (h * bh / 2)).clamp(0, h)
  16. br_x = ((cx * w) + (w * bw / 2)).clamp(0, w)
  17. br_y = ((cy * h) + (h * bh / 2)).clamp(0, h)
  18. bboxes = torch.stack([tl_x, tl_y, br_x, br_y], dim=0).T
  19. return bboxes
  20. class TestVisualizationHook(TestCase):
  21. def setUp(self) -> None:
  22. DetLocalVisualizer.get_instance('current_visualizer')
  23. pred_instances = InstanceData()
  24. pred_instances.bboxes = _rand_bboxes(5, 10, 12)
  25. pred_instances.labels = torch.randint(0, 2, (5, ))
  26. pred_instances.scores = torch.rand((5, ))
  27. pred_det_data_sample = DetDataSample()
  28. pred_det_data_sample.set_metainfo({
  29. 'img_path':
  30. osp.join(osp.dirname(__file__), '../../data/color.jpg')
  31. })
  32. pred_det_data_sample.pred_instances = pred_instances
  33. self.outputs = [pred_det_data_sample] * 2
  34. def test_after_val_iter(self):
  35. runner = Mock()
  36. runner.iter = 1
  37. hook = DetVisualizationHook()
  38. hook.after_val_iter(runner, 1, {}, self.outputs)
  39. def test_after_test_iter(self):
  40. runner = Mock()
  41. runner.iter = 1
  42. hook = DetVisualizationHook(draw=True)
  43. hook.after_test_iter(runner, 1, {}, self.outputs)
  44. self.assertEqual(hook._test_index, 2)
  45. # test
  46. timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
  47. test_out_dir = timestamp + '1'
  48. runner.work_dir = timestamp
  49. runner.timestamp = '1'
  50. hook = DetVisualizationHook(draw=False, test_out_dir=test_out_dir)
  51. hook.after_test_iter(runner, 1, {}, self.outputs)
  52. self.assertTrue(not osp.exists(f'{timestamp}/1/{test_out_dir}'))
  53. hook = DetVisualizationHook(draw=True, test_out_dir=test_out_dir)
  54. hook.after_test_iter(runner, 1, {}, self.outputs)
  55. self.assertTrue(osp.exists(f'{timestamp}/1/{test_out_dir}'))
  56. shutil.rmtree(f'{timestamp}')
  57. class TestTrackVisualizationHook(TestCase):
  58. def setUp(self) -> None:
  59. TrackLocalVisualizer.get_instance('visualizer')
  60. # pseudo data_batch
  61. self.data_batch = dict(data_samples=None, inputs=None)
  62. pred_instances_data = dict(
  63. bboxes=torch.tensor([[100, 100, 200, 200], [150, 150, 400, 200]]),
  64. instances_id=torch.tensor([1, 2]),
  65. labels=torch.tensor([0, 1]),
  66. scores=torch.tensor([0.955, 0.876]))
  67. pred_instances = InstanceData(**pred_instances_data)
  68. img_data_sample = DetDataSample()
  69. img_data_sample.pred_track_instances = pred_instances
  70. img_data_sample.gt_instances = pred_instances
  71. img_data_sample.set_metainfo(
  72. dict(
  73. img_path=osp.join(
  74. osp.dirname(__file__), '../../data/color.jpg'),
  75. scale_factor=(1.0, 1.0)))
  76. track_data_sample = TrackDataSample()
  77. track_data_sample.video_data_samples = [img_data_sample]
  78. track_data_sample.set_metainfo(dict(ori_length=1))
  79. self.outputs = [track_data_sample]
  80. def test_after_val_iter_image(self):
  81. runner = Mock()
  82. runner.iter = 1
  83. hook = TrackVisualizationHook(frame_interval=10, draw=True)
  84. hook.after_val_iter(runner, 9, self.data_batch, self.outputs)
  85. def test_after_test_iter(self):
  86. runner = Mock()
  87. runner.iter = 1
  88. hook = TrackVisualizationHook(frame_interval=10, draw=True)
  89. hook.after_val_iter(runner, 9, self.data_batch, self.outputs)
  90. # test test_out_dir
  91. timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
  92. test_out_dir = timestamp + '1'
  93. runner.work_dir = timestamp
  94. runner.timestamp = '1'
  95. hook = TrackVisualizationHook(
  96. frame_interval=10, draw=True, test_out_dir=test_out_dir)
  97. hook.after_test_iter(runner, 9, self.data_batch, self.outputs)
  98. self.assertTrue(osp.exists(f'{timestamp}/1/{test_out_dir}'))
  99. shutil.rmtree(f'{timestamp}')