test_det_data_sample.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. from unittest import TestCase
  2. import numpy as np
  3. import pytest
  4. import torch
  5. from mmengine.structures import InstanceData, PixelData
  6. from mmdet.structures import DetDataSample
  7. def _equal(a, b):
  8. if isinstance(a, (torch.Tensor, np.ndarray)):
  9. return (a == b).all()
  10. else:
  11. return a == b
  12. class TestDetDataSample(TestCase):
  13. def test_init(self):
  14. meta_info = dict(
  15. img_size=[256, 256],
  16. scale_factor=np.array([1.5, 1.5]),
  17. img_shape=torch.rand(4))
  18. det_data_sample = DetDataSample(metainfo=meta_info)
  19. assert 'img_size' in det_data_sample
  20. assert det_data_sample.img_size == [256, 256]
  21. assert det_data_sample.get('img_size') == [256, 256]
  22. def test_setter(self):
  23. det_data_sample = DetDataSample()
  24. # test gt_instances
  25. gt_instances_data = dict(
  26. bboxes=torch.rand(4, 4),
  27. labels=torch.rand(4),
  28. masks=np.random.rand(4, 2, 2))
  29. gt_instances = InstanceData(**gt_instances_data)
  30. det_data_sample.gt_instances = gt_instances
  31. assert 'gt_instances' in det_data_sample
  32. assert _equal(det_data_sample.gt_instances.bboxes,
  33. gt_instances_data['bboxes'])
  34. assert _equal(det_data_sample.gt_instances.labels,
  35. gt_instances_data['labels'])
  36. assert _equal(det_data_sample.gt_instances.masks,
  37. gt_instances_data['masks'])
  38. # test pred_instances
  39. pred_instances_data = dict(
  40. bboxes=torch.rand(2, 4),
  41. labels=torch.rand(2),
  42. masks=np.random.rand(2, 2, 2))
  43. pred_instances = InstanceData(**pred_instances_data)
  44. det_data_sample.pred_instances = pred_instances
  45. assert 'pred_instances' in det_data_sample
  46. assert _equal(det_data_sample.pred_instances.bboxes,
  47. pred_instances_data['bboxes'])
  48. assert _equal(det_data_sample.pred_instances.labels,
  49. pred_instances_data['labels'])
  50. assert _equal(det_data_sample.pred_instances.masks,
  51. pred_instances_data['masks'])
  52. # test pred_track_instances
  53. pred_track_instances_data = dict(
  54. bboxes=torch.rand(2, 4),
  55. labels=torch.rand(2),
  56. masks=np.random.rand(2, 2, 2))
  57. pred_instances = InstanceData(**pred_track_instances_data)
  58. det_data_sample.pred_instances = pred_instances
  59. assert 'pred_instances' in det_data_sample
  60. assert _equal(det_data_sample.pred_instances.bboxes,
  61. pred_track_instances_data['bboxes'])
  62. assert _equal(det_data_sample.pred_instances.labels,
  63. pred_track_instances_data['labels'])
  64. assert _equal(det_data_sample.pred_instances.masks,
  65. pred_track_instances_data['masks'])
  66. # test proposals
  67. proposals_data = dict(bboxes=torch.rand(4, 4), labels=torch.rand(4))
  68. proposals = InstanceData(**proposals_data)
  69. det_data_sample.proposals = proposals
  70. assert 'proposals' in det_data_sample
  71. assert _equal(det_data_sample.proposals.bboxes,
  72. proposals_data['bboxes'])
  73. assert _equal(det_data_sample.proposals.labels,
  74. proposals_data['labels'])
  75. # test ignored_instances
  76. ignored_instances_data = dict(
  77. bboxes=torch.rand(4, 4), labels=torch.rand(4))
  78. ignored_instances = InstanceData(**ignored_instances_data)
  79. det_data_sample.ignored_instances = ignored_instances
  80. assert 'ignored_instances' in det_data_sample
  81. assert _equal(det_data_sample.ignored_instances.bboxes,
  82. ignored_instances_data['bboxes'])
  83. assert _equal(det_data_sample.ignored_instances.labels,
  84. ignored_instances_data['labels'])
  85. # test gt_panoptic_seg
  86. gt_panoptic_seg_data = dict(panoptic_seg=torch.rand(5, 4))
  87. gt_panoptic_seg = PixelData(**gt_panoptic_seg_data)
  88. det_data_sample.gt_panoptic_seg = gt_panoptic_seg
  89. assert 'gt_panoptic_seg' in det_data_sample
  90. assert _equal(det_data_sample.gt_panoptic_seg.panoptic_seg,
  91. gt_panoptic_seg_data['panoptic_seg'])
  92. # test pred_panoptic_seg
  93. pred_panoptic_seg_data = dict(panoptic_seg=torch.rand(5, 4))
  94. pred_panoptic_seg = PixelData(**pred_panoptic_seg_data)
  95. det_data_sample.pred_panoptic_seg = pred_panoptic_seg
  96. assert 'pred_panoptic_seg' in det_data_sample
  97. assert _equal(det_data_sample.pred_panoptic_seg.panoptic_seg,
  98. pred_panoptic_seg_data['panoptic_seg'])
  99. # test gt_sem_seg
  100. gt_segm_seg_data = dict(segm_seg=torch.rand(5, 4, 2))
  101. gt_segm_seg = PixelData(**gt_segm_seg_data)
  102. det_data_sample.gt_segm_seg = gt_segm_seg
  103. assert 'gt_segm_seg' in det_data_sample
  104. assert _equal(det_data_sample.gt_segm_seg.segm_seg,
  105. gt_segm_seg_data['segm_seg'])
  106. # test pred_segm_seg
  107. pred_segm_seg_data = dict(segm_seg=torch.rand(5, 4, 2))
  108. pred_segm_seg = PixelData(**pred_segm_seg_data)
  109. det_data_sample.pred_segm_seg = pred_segm_seg
  110. assert 'pred_segm_seg' in det_data_sample
  111. assert _equal(det_data_sample.pred_segm_seg.segm_seg,
  112. pred_segm_seg_data['segm_seg'])
  113. # test type error
  114. with pytest.raises(AssertionError):
  115. det_data_sample.pred_instances = torch.rand(2, 4)
  116. with pytest.raises(AssertionError):
  117. det_data_sample.pred_panoptic_seg = torch.rand(2, 4)
  118. with pytest.raises(AssertionError):
  119. det_data_sample.pred_sem_seg = torch.rand(2, 4)
  120. def test_deleter(self):
  121. gt_instances_data = dict(
  122. bboxes=torch.rand(4, 4),
  123. labels=torch.rand(4),
  124. masks=np.random.rand(4, 2, 2))
  125. det_data_sample = DetDataSample()
  126. gt_instances = InstanceData(data=gt_instances_data)
  127. det_data_sample.gt_instances = gt_instances
  128. assert 'gt_instances' in det_data_sample
  129. del det_data_sample.gt_instances
  130. assert 'gt_instances' not in det_data_sample
  131. pred_panoptic_seg_data = torch.rand(5, 4)
  132. pred_panoptic_seg = PixelData(data=pred_panoptic_seg_data)
  133. det_data_sample.pred_panoptic_seg = pred_panoptic_seg
  134. assert 'pred_panoptic_seg' in det_data_sample
  135. del det_data_sample.pred_panoptic_seg
  136. assert 'pred_panoptic_seg' not in det_data_sample
  137. pred_segm_seg_data = dict(segm_seg=torch.rand(5, 4, 2))
  138. pred_segm_seg = PixelData(**pred_segm_seg_data)
  139. det_data_sample.pred_segm_seg = pred_segm_seg
  140. assert 'pred_segm_seg' in det_data_sample
  141. del det_data_sample.pred_segm_seg
  142. assert 'pred_segm_seg' not in det_data_sample