test_standard_roi_head.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import unittest
  3. from unittest import TestCase
  4. import torch
  5. from mmengine.config import Config
  6. from parameterized import parameterized
  7. from mmdet.registry import MODELS
  8. from mmdet.testing import demo_mm_inputs, demo_mm_proposals
  9. from mmdet.utils import register_all_modules
# Called once at import time, before any test runs; presumably required so
# that MODELS.build() below can resolve the ``type`` names used in the fake
# RoI head configs (StandardRoIHead, SingleRoIExtractor, ...).
register_all_modules()
  11. def _fake_roi_head(with_shared_head=False):
  12. """Set a fake roi head config."""
  13. if not with_shared_head:
  14. roi_head = Config(
  15. dict(
  16. type='StandardRoIHead',
  17. bbox_roi_extractor=dict(
  18. type='SingleRoIExtractor',
  19. roi_layer=dict(
  20. type='RoIAlign', output_size=7, sampling_ratio=0),
  21. out_channels=1,
  22. featmap_strides=[4, 8, 16, 32]),
  23. bbox_head=dict(
  24. type='Shared2FCBBoxHead',
  25. in_channels=1,
  26. fc_out_channels=1,
  27. num_classes=4),
  28. mask_roi_extractor=dict(
  29. type='SingleRoIExtractor',
  30. roi_layer=dict(
  31. type='RoIAlign', output_size=14, sampling_ratio=0),
  32. out_channels=1,
  33. featmap_strides=[4, 8, 16, 32]),
  34. mask_head=dict(
  35. type='FCNMaskHead',
  36. num_convs=1,
  37. in_channels=1,
  38. conv_out_channels=1,
  39. num_classes=4),
  40. train_cfg=dict(
  41. assigner=dict(
  42. type='MaxIoUAssigner',
  43. pos_iou_thr=0.5,
  44. neg_iou_thr=0.5,
  45. min_pos_iou=0.5,
  46. match_low_quality=True,
  47. ignore_iof_thr=-1),
  48. sampler=dict(
  49. type='RandomSampler',
  50. num=512,
  51. pos_fraction=0.25,
  52. neg_pos_ub=-1,
  53. add_gt_as_proposals=True),
  54. mask_size=28,
  55. pos_weight=-1,
  56. debug=False),
  57. test_cfg=dict(
  58. score_thr=0.05,
  59. nms=dict(type='nms', iou_threshold=0.5),
  60. max_per_img=100,
  61. mask_thr_binary=0.5)))
  62. else:
  63. roi_head = Config(
  64. dict(
  65. type='StandardRoIHead',
  66. shared_head=dict(
  67. type='ResLayer',
  68. depth=50,
  69. stage=3,
  70. stride=2,
  71. dilation=1,
  72. style='caffe',
  73. norm_cfg=dict(type='BN', requires_grad=False),
  74. norm_eval=True),
  75. bbox_roi_extractor=dict(
  76. type='SingleRoIExtractor',
  77. roi_layer=dict(
  78. type='RoIAlign', output_size=14, sampling_ratio=0),
  79. out_channels=1,
  80. featmap_strides=[16]),
  81. bbox_head=dict(
  82. type='BBoxHead',
  83. with_avg_pool=True,
  84. in_channels=2048,
  85. roi_feat_size=7,
  86. num_classes=4),
  87. mask_roi_extractor=None,
  88. mask_head=dict(
  89. type='FCNMaskHead',
  90. num_convs=0,
  91. in_channels=2048,
  92. conv_out_channels=1,
  93. num_classes=4),
  94. train_cfg=dict(
  95. assigner=dict(
  96. type='MaxIoUAssigner',
  97. pos_iou_thr=0.5,
  98. neg_iou_thr=0.5,
  99. min_pos_iou=0.5,
  100. match_low_quality=False,
  101. ignore_iof_thr=-1),
  102. sampler=dict(
  103. type='RandomSampler',
  104. num=512,
  105. pos_fraction=0.25,
  106. neg_pos_ub=-1,
  107. add_gt_as_proposals=True),
  108. mask_size=14,
  109. pos_weight=-1,
  110. debug=False),
  111. test_cfg=dict(
  112. score_thr=0.05,
  113. nms=dict(type='nms', iou_threshold=0.5),
  114. max_per_img=100,
  115. mask_thr_binary=0.5)))
  116. return roi_head
  117. class TestStandardRoIHead(TestCase):
  118. def test_init(self):
  119. """Test init standard RoI head."""
  120. # Normal Mask R-CNN RoI head
  121. roi_head_cfg = _fake_roi_head()
  122. roi_head = MODELS.build(roi_head_cfg)
  123. self.assertTrue(roi_head.with_bbox)
  124. self.assertTrue(roi_head.with_mask)
  125. # Mask R-CNN RoI head with shared_head
  126. roi_head_cfg = _fake_roi_head(with_shared_head=True)
  127. roi_head = MODELS.build(roi_head_cfg)
  128. self.assertTrue(roi_head.with_bbox)
  129. self.assertTrue(roi_head.with_mask)
  130. self.assertTrue(roi_head.with_shared_head)
  131. @parameterized.expand([(False, ), (True, )])
  132. def test_standard_roi_head_loss(self, with_shared_head):
  133. """Tests standard roi head loss when truth is empty and non-empty."""
  134. if not torch.cuda.is_available():
  135. # RoI pooling only support in GPU
  136. return unittest.skip('test requires GPU and torch+cuda')
  137. s = 256
  138. roi_head_cfg = _fake_roi_head(with_shared_head=with_shared_head)
  139. roi_head = MODELS.build(roi_head_cfg)
  140. roi_head = roi_head.cuda()
  141. feats = []
  142. for i in range(len(roi_head.bbox_roi_extractor.featmap_strides)):
  143. if not with_shared_head:
  144. feats.append(
  145. torch.rand(1, 1, s // (2**(i + 2)),
  146. s // (2**(i + 2))).to(device='cuda'))
  147. else:
  148. feats.append(
  149. torch.rand(1, 1024, s // (2**(i + 2)),
  150. s // (2**(i + 2))).to(device='cuda'))
  151. feats = tuple(feats)
  152. # When truth is non-empty then both cls, box, and mask loss
  153. # should be nonzero for random inputs
  154. image_shapes = [(3, s, s)]
  155. batch_data_samples = demo_mm_inputs(
  156. batch_size=1,
  157. image_shapes=image_shapes,
  158. num_items=[1],
  159. num_classes=4,
  160. with_mask=True,
  161. device='cuda')['data_samples']
  162. proposals_list = demo_mm_proposals(
  163. image_shapes=image_shapes, num_proposals=100, device='cuda')
  164. out = roi_head.loss(feats, proposals_list, batch_data_samples)
  165. loss_cls = out['loss_cls']
  166. loss_bbox = out['loss_bbox']
  167. loss_mask = out['loss_mask']
  168. self.assertGreater(loss_cls.sum(), 0, 'cls loss should be non-zero')
  169. self.assertGreater(loss_bbox.sum(), 0, 'box loss should be non-zero')
  170. self.assertGreater(loss_mask.sum(), 0, 'mask loss should be non-zero')
  171. # When there is no truth, the cls loss should be nonzero but
  172. # there should be no box and mask loss.
  173. batch_data_samples = demo_mm_inputs(
  174. batch_size=1,
  175. image_shapes=image_shapes,
  176. num_items=[0],
  177. num_classes=4,
  178. with_mask=True,
  179. device='cuda')['data_samples']
  180. proposals_list = demo_mm_proposals(
  181. image_shapes=image_shapes, num_proposals=100, device='cuda')
  182. out = roi_head.loss(feats, proposals_list, batch_data_samples)
  183. empty_cls_loss = out['loss_cls']
  184. empty_bbox_loss = out['loss_bbox']
  185. empty_mask_loss = out['loss_mask']
  186. self.assertGreater(empty_cls_loss.sum(), 0,
  187. 'cls loss should be non-zero')
  188. self.assertEqual(
  189. empty_bbox_loss.sum(), 0,
  190. 'there should be no box loss when there are no true boxes')
  191. self.assertEqual(
  192. empty_mask_loss.sum(), 0,
  193. 'there should be no mask loss when there are no true boxes')