# test_wrappers.py: unit tests for the MultiBranch and RandomOrder data-transform wrappers.
import copy
import os.path as osp
import unittest

import numpy as np
from mmcv.transforms import Compose

from mmdet.datasets.transforms import MultiBranch, RandomOrder
from mmdet.utils import register_all_modules

from .utils import construct_toy_data
# Runs once at import time, before any test method. The tests below build
# transforms from string configs such as dict(type='ShearX') and
# dict(type='MultiBranch'); presumably this call registers mmdet's modules
# so those `type` names can be resolved — NOTE(review): confirm it also sets
# the default registry scope.
register_all_modules()
  9. class TestMultiBranch(unittest.TestCase):
  10. def setUp(self):
  11. """Setup the model and optimizer which are used in every test method.
  12. TestCase calls functions in this order: setUp() -> testMethod() ->
  13. tearDown() -> cleanUp()
  14. """
  15. data_prefix = osp.join(osp.dirname(__file__), '../../data')
  16. img_path = osp.join(data_prefix, 'color.jpg')
  17. seg_map = osp.join(data_prefix, 'gray.jpg')
  18. self.meta_keys = ('img_id', 'img_path', 'ori_shape', 'img_shape',
  19. 'scale_factor', 'flip', 'flip_direction',
  20. 'homography_matrix')
  21. self.results = {
  22. 'img_path':
  23. img_path,
  24. 'img_id':
  25. 12345,
  26. 'img_shape': (300, 400),
  27. 'seg_map_path':
  28. seg_map,
  29. 'instances': [{
  30. 'bbox': [0, 0, 10, 20],
  31. 'bbox_label': 1,
  32. 'mask': [[0, 0, 0, 20, 10, 20, 10, 0]],
  33. 'ignore_flag': 0
  34. }, {
  35. 'bbox': [10, 10, 110, 120],
  36. 'bbox_label': 2,
  37. 'mask': [[10, 10, 110, 10, 110, 120, 110, 10]],
  38. 'ignore_flag': 0
  39. }, {
  40. 'bbox': [50, 50, 60, 80],
  41. 'bbox_label': 2,
  42. 'mask': [[50, 50, 60, 50, 60, 80, 50, 80]],
  43. 'ignore_flag': 1
  44. }]
  45. }
  46. self.branch_field = ['sup', 'sup_teacher', 'sup_student']
  47. self.weak_pipeline = [
  48. dict(type='ShearX'),
  49. dict(type='PackDetInputs', meta_keys=self.meta_keys)
  50. ]
  51. self.strong_pipeline = [
  52. dict(type='ShearX'),
  53. dict(type='ShearY'),
  54. dict(type='PackDetInputs', meta_keys=self.meta_keys)
  55. ]
  56. self.labeled_pipeline = [
  57. dict(type='LoadImageFromFile'),
  58. dict(
  59. type='LoadAnnotations',
  60. with_bbox=True,
  61. with_mask=True,
  62. with_seg=True),
  63. dict(type='Resize', scale=(1333, 800), keep_ratio=True),
  64. dict(type='RandomFlip', prob=0.5),
  65. dict(
  66. type='MultiBranch',
  67. branch_field=self.branch_field,
  68. sup_teacher=self.weak_pipeline,
  69. sup_student=self.strong_pipeline),
  70. ]
  71. self.unlabeled_pipeline = [
  72. dict(type='LoadImageFromFile'),
  73. dict(type='Resize', scale=(1333, 800), keep_ratio=True),
  74. dict(type='RandomFlip', prob=0.5),
  75. dict(
  76. type='MultiBranch',
  77. branch_field=self.branch_field,
  78. unsup_teacher=self.weak_pipeline,
  79. unsup_student=self.strong_pipeline),
  80. ]
  81. def test_transform(self):
  82. labeled_pipeline = Compose(self.labeled_pipeline)
  83. labeled_results = labeled_pipeline(copy.deepcopy(self.results))
  84. unlabeled_pipeline = Compose(self.unlabeled_pipeline)
  85. unlabeled_results = unlabeled_pipeline(copy.deepcopy(self.results))
  86. # test branch sup_teacher and sup_student
  87. sup_branches = ['sup_teacher', 'sup_student']
  88. for branch in sup_branches:
  89. self.assertIn(branch, labeled_results['data_samples'])
  90. self.assertIn('homography_matrix',
  91. labeled_results['data_samples'][branch])
  92. self.assertIn('labels',
  93. labeled_results['data_samples'][branch].gt_instances)
  94. self.assertIn('bboxes',
  95. labeled_results['data_samples'][branch].gt_instances)
  96. self.assertIn('masks',
  97. labeled_results['data_samples'][branch].gt_instances)
  98. self.assertIn('gt_sem_seg',
  99. labeled_results['data_samples'][branch])
  100. # test branch unsup_teacher and unsup_student
  101. unsup_branches = ['unsup_teacher', 'unsup_student']
  102. for branch in unsup_branches:
  103. self.assertIn(branch, unlabeled_results['data_samples'])
  104. self.assertIn('homography_matrix',
  105. unlabeled_results['data_samples'][branch])
  106. self.assertNotIn(
  107. 'labels',
  108. unlabeled_results['data_samples'][branch].gt_instances)
  109. self.assertNotIn(
  110. 'bboxes',
  111. unlabeled_results['data_samples'][branch].gt_instances)
  112. self.assertNotIn(
  113. 'masks',
  114. unlabeled_results['data_samples'][branch].gt_instances)
  115. self.assertNotIn('gt_sem_seg',
  116. unlabeled_results['data_samples'][branch])
  117. def test_repr(self):
  118. pipeline = [dict(type='PackDetInputs', meta_keys=())]
  119. transform = MultiBranch(
  120. branch_field=self.branch_field, sup=pipeline, unsup=pipeline)
  121. self.assertEqual(
  122. repr(transform),
  123. ("MultiBranch(branch_pipelines=['sup', 'unsup'])"))
  124. class TestRandomOrder(unittest.TestCase):
  125. def setUp(self):
  126. """Setup the model and optimizer which are used in every test method.
  127. TestCase calls functions in this order: setUp() -> testMethod() ->
  128. tearDown() -> cleanUp()
  129. """
  130. self.results = construct_toy_data(poly2mask=True)
  131. self.pipeline = [
  132. dict(type='Sharpness'),
  133. dict(type='Contrast'),
  134. dict(type='Brightness'),
  135. dict(type='Rotate'),
  136. dict(type='ShearX'),
  137. dict(type='TranslateY')
  138. ]
  139. def test_transform(self):
  140. transform = RandomOrder(self.pipeline)
  141. results = transform(copy.deepcopy(self.results))
  142. self.assertEqual(results['img_shape'], self.results['img_shape'])
  143. self.assertEqual(results['gt_bboxes'].shape,
  144. self.results['gt_bboxes'].shape)
  145. self.assertEqual(results['gt_bboxes_labels'],
  146. self.results['gt_bboxes_labels'])
  147. self.assertEqual(results['gt_ignore_flags'],
  148. self.results['gt_ignore_flags'])
  149. self.assertEqual(results['gt_masks'].masks.shape,
  150. self.results['gt_masks'].masks.shape)
  151. self.assertEqual(results['gt_seg_map'].shape,
  152. self.results['gt_seg_map'].shape)
  153. def test_repr(self):
  154. transform = RandomOrder(self.pipeline)
  155. self.assertEqual(
  156. repr(transform), ('RandomOrder(Sharpness, Contrast, '
  157. 'Brightness, Rotate, ShearX, TranslateY, )'))