test_cascade_rpn_head.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from unittest import TestCase
  3. import torch
  4. from mmengine.config import ConfigDict
  5. from mmengine.structures import InstanceData
  6. from mmdet.models.dense_heads import CascadeRPNHead
  7. from mmdet.structures import DetDataSample
  8. rpn_weight = 0.7
  9. cascade_rpn_config = ConfigDict(
  10. dict(
  11. num_stages=2,
  12. num_classes=1,
  13. stages=[
  14. dict(
  15. type='StageCascadeRPNHead',
  16. in_channels=1,
  17. feat_channels=1,
  18. anchor_generator=dict(
  19. type='AnchorGenerator',
  20. scales=[8],
  21. ratios=[1.0],
  22. strides=[4, 8, 16, 32, 64]),
  23. adapt_cfg=dict(type='dilation', dilation=3),
  24. bridged_feature=True,
  25. with_cls=False,
  26. reg_decoded_bbox=True,
  27. bbox_coder=dict(
  28. type='DeltaXYWHBBoxCoder',
  29. target_means=(.0, .0, .0, .0),
  30. target_stds=(0.1, 0.1, 0.5, 0.5)),
  31. loss_bbox=dict(
  32. type='IoULoss', linear=True,
  33. loss_weight=10.0 * rpn_weight)),
  34. dict(
  35. type='StageCascadeRPNHead',
  36. in_channels=1,
  37. feat_channels=1,
  38. adapt_cfg=dict(type='offset'),
  39. bridged_feature=False,
  40. with_cls=True,
  41. reg_decoded_bbox=True,
  42. bbox_coder=dict(
  43. type='DeltaXYWHBBoxCoder',
  44. target_means=(.0, .0, .0, .0),
  45. target_stds=(0.05, 0.05, 0.1, 0.1)),
  46. loss_cls=dict(
  47. type='CrossEntropyLoss',
  48. use_sigmoid=True,
  49. loss_weight=1.0 * rpn_weight),
  50. loss_bbox=dict(
  51. type='IoULoss', linear=True,
  52. loss_weight=10.0 * rpn_weight))
  53. ],
  54. train_cfg=[
  55. dict(
  56. assigner=dict(
  57. type='RegionAssigner', center_ratio=0.2, ignore_ratio=0.5),
  58. allowed_border=-1,
  59. pos_weight=-1,
  60. debug=False),
  61. dict(
  62. assigner=dict(
  63. type='MaxIoUAssigner',
  64. pos_iou_thr=0.7,
  65. neg_iou_thr=0.7,
  66. min_pos_iou=0.3,
  67. ignore_iof_thr=-1),
  68. sampler=dict(
  69. type='RandomSampler',
  70. num=256,
  71. pos_fraction=0.5,
  72. neg_pos_ub=-1,
  73. add_gt_as_proposals=False),
  74. allowed_border=-1,
  75. pos_weight=-1,
  76. debug=False)
  77. ],
  78. test_cfg=dict(max_per_img=300, nms=dict(iou_threshold=0.8))))
  79. class TestStageCascadeRPNHead(TestCase):
  80. def test_cascade_rpn_head_loss(self):
  81. """Tests cascade rpn head loss when truth is empty and non-empty."""
  82. cascade_rpn_head = CascadeRPNHead(**cascade_rpn_config)
  83. s = 256
  84. feats = [
  85. torch.rand(1, 1, s // stride[1], s // stride[0])
  86. for stride in cascade_rpn_head.stages[0].prior_generator.strides
  87. ]
  88. img_metas = {
  89. 'img_shape': (s, s),
  90. 'pad_shape': (s, s),
  91. 'scale_factor': 1,
  92. }
  93. sample = DetDataSample()
  94. sample.set_metainfo(img_metas)
  95. # Test that empty ground truth encourages the network to
  96. # predict background
  97. gt_instances = InstanceData()
  98. gt_instances.bboxes = torch.empty((0, 4))
  99. gt_instances.labels = torch.LongTensor([])
  100. sample.gt_instances = gt_instances
  101. empty_gt_losses = cascade_rpn_head.loss(feats, [sample])
  102. for key, loss in empty_gt_losses.items():
  103. loss = sum(loss)
  104. if 'cls' in key:
  105. self.assertGreater(loss.item(), 0,
  106. 'cls loss should be non-zero')
  107. elif 'reg' in key:
  108. self.assertEqual(
  109. loss.item(), 0,
  110. 'there should be no reg loss when no ground true boxes')
  111. # When truth is non-empty then all cls, box loss and centerness loss
  112. # should be nonzero for random inputs
  113. gt_instances = InstanceData()
  114. gt_instances.bboxes = torch.Tensor(
  115. [[23.6667, 23.8757, 238.6326, 151.8874]])
  116. gt_instances.labels = torch.LongTensor([0])
  117. sample.gt_instances = gt_instances
  118. one_gt_losses = cascade_rpn_head.loss(feats, [sample])
  119. for loss in one_gt_losses.values():
  120. loss = sum(loss)
  121. self.assertGreater(
  122. loss.item(), 0,
  123. 'cls loss, or box loss, or iou loss should be non-zero')
  124. def test_cascade_rpn_head_loss_and_predict(self):
  125. """Tests cascade rpn head loss and predict function."""
  126. cascade_rpn_head = CascadeRPNHead(**cascade_rpn_config)
  127. s = 256
  128. feats = [
  129. torch.rand(1, 1, s // stride[1], s // stride[0])
  130. for stride in cascade_rpn_head.stages[0].prior_generator.strides
  131. ]
  132. img_metas = {
  133. 'img_shape': (s, s),
  134. 'pad_shape': (s, s),
  135. 'scale_factor': 1,
  136. }
  137. sample = DetDataSample()
  138. sample.set_metainfo(img_metas)
  139. gt_instances = InstanceData()
  140. gt_instances.bboxes = torch.empty((0, 4))
  141. gt_instances.labels = torch.LongTensor([])
  142. sample.gt_instances = gt_instances
  143. proposal_cfg = ConfigDict(
  144. dict(max_per_img=300, nms=dict(iou_threshold=0.8)))
  145. cascade_rpn_head.loss_and_predict(feats, [sample], proposal_cfg)
  146. def test_cascade_rpn_head_predict(self):
  147. """Tests cascade rpn head predict function."""
  148. cascade_rpn_head = CascadeRPNHead(**cascade_rpn_config)
  149. s = 256
  150. feats = [
  151. torch.rand(1, 1, s // stride[1], s // stride[0])
  152. for stride in cascade_rpn_head.stages[0].prior_generator.strides
  153. ]
  154. img_metas = {
  155. 'img_shape': (s, s),
  156. 'pad_shape': (s, s),
  157. 'scale_factor': 1,
  158. }
  159. sample = DetDataSample()
  160. sample.set_metainfo(img_metas)
  161. cascade_rpn_head.predict(feats, [sample])