test_ga_rpn_head.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from unittest import TestCase
  3. import torch
  4. from mmengine.config import ConfigDict
  5. from mmengine.structures import InstanceData
  6. from mmdet.models.dense_heads import GARPNHead
  7. ga_rpn_config = ConfigDict(
  8. dict(
  9. num_classes=1,
  10. in_channels=4,
  11. feat_channels=4,
  12. approx_anchor_generator=dict(
  13. type='AnchorGenerator',
  14. octave_base_scale=8,
  15. scales_per_octave=3,
  16. ratios=[0.5, 1.0, 2.0],
  17. strides=[4, 8, 16, 32, 64]),
  18. square_anchor_generator=dict(
  19. type='AnchorGenerator',
  20. ratios=[1.0],
  21. scales=[8],
  22. strides=[4, 8, 16, 32, 64]),
  23. anchor_coder=dict(
  24. type='DeltaXYWHBBoxCoder',
  25. target_means=[.0, .0, .0, .0],
  26. target_stds=[0.07, 0.07, 0.14, 0.14]),
  27. bbox_coder=dict(
  28. type='DeltaXYWHBBoxCoder',
  29. target_means=[.0, .0, .0, .0],
  30. target_stds=[0.07, 0.07, 0.11, 0.11]),
  31. loc_filter_thr=0.01,
  32. loss_loc=dict(
  33. type='FocalLoss',
  34. use_sigmoid=True,
  35. gamma=2.0,
  36. alpha=0.25,
  37. loss_weight=1.0),
  38. loss_shape=dict(type='BoundedIoULoss', beta=0.2, loss_weight=1.0),
  39. loss_cls=dict(
  40. type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
  41. loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0),
  42. train_cfg=dict(
  43. ga_assigner=dict(
  44. type='ApproxMaxIoUAssigner',
  45. pos_iou_thr=0.7,
  46. neg_iou_thr=0.3,
  47. min_pos_iou=0.3,
  48. ignore_iof_thr=-1),
  49. ga_sampler=dict(
  50. type='RandomSampler',
  51. num=256,
  52. pos_fraction=0.5,
  53. neg_pos_ub=-1,
  54. add_gt_as_proposals=False),
  55. assigner=dict(
  56. type='MaxIoUAssigner',
  57. pos_iou_thr=0.7,
  58. neg_iou_thr=0.3,
  59. min_pos_iou=0.3,
  60. match_low_quality=True,
  61. ignore_iof_thr=-1),
  62. sampler=dict(
  63. type='RandomSampler',
  64. num=256,
  65. pos_fraction=0.5,
  66. neg_pos_ub=-1,
  67. add_gt_as_proposals=False),
  68. allowed_border=-1,
  69. center_ratio=0.2,
  70. ignore_ratio=0.5,
  71. pos_weight=-1,
  72. debug=False),
  73. test_cfg=dict(
  74. nms_pre=1000,
  75. ms_post=1000,
  76. max_per_img=300,
  77. nms=dict(type='nms', iou_threshold=0.7),
  78. min_bbox_size=0)))
  79. class TestGARPNHead(TestCase):
  80. def test_ga_rpn_head_loss(self):
  81. """Tests ga rpn head loss."""
  82. s = 256
  83. img_metas = [{
  84. 'img_shape': (s, s),
  85. 'pad_shape': (s, s),
  86. 'scale_factor': (1, 1)
  87. }]
  88. ga_rpn_head = GARPNHead(**ga_rpn_config)
  89. feats = (
  90. torch.rand(1, 4, s // stride[1], s // stride[0])
  91. for stride in ga_rpn_head.square_anchor_generator.strides)
  92. outs = ga_rpn_head(feats)
  93. # When truth is non-empty then all cls, box loss and centerness loss
  94. # should be nonzero for random inputs
  95. gt_instances = InstanceData()
  96. gt_instances.bboxes = torch.Tensor(
  97. [[23.6667, 23.8757, 238.6326, 151.8874]])
  98. gt_instances.labels = torch.LongTensor([0])
  99. one_gt_losses = ga_rpn_head.loss_by_feat(*outs, [gt_instances],
  100. img_metas)
  101. onegt_cls_loss = sum(one_gt_losses['loss_rpn_cls']).item()
  102. onegt_box_loss = sum(one_gt_losses['loss_rpn_bbox']).item()
  103. onegt_shape_loss = sum(one_gt_losses['loss_anchor_shape']).item()
  104. onegt_loc_loss = sum(one_gt_losses['loss_anchor_loc']).item()
  105. self.assertGreater(onegt_cls_loss, 0, 'cls loss should be non-zero')
  106. self.assertGreater(onegt_box_loss, 0, 'box loss should be non-zero')
  107. self.assertGreater(onegt_shape_loss, 0,
  108. 'shape loss should be non-zero')
  109. self.assertGreater(onegt_loc_loss, 0,
  110. 'location loss should be non-zero')
  111. def test_ga_rpn_head_predict_by_feat(self):
  112. s = 256
  113. img_metas = [{
  114. 'img_shape': (s, s),
  115. 'pad_shape': (s, s),
  116. 'scale_factor': (1, 1)
  117. }]
  118. ga_rpn_head = GARPNHead(**ga_rpn_config)
  119. feats = (
  120. torch.rand(1, 4, s // stride[1], s // stride[0])
  121. for stride in ga_rpn_head.square_anchor_generator.strides)
  122. outs = ga_rpn_head(feats)
  123. cfg = ConfigDict(
  124. dict(
  125. nms_pre=2000,
  126. nms_post=1000,
  127. max_per_img=300,
  128. nms=dict(type='nms', iou_threshold=0.7),
  129. min_bbox_size=0))
  130. ga_rpn_head.predict_by_feat(
  131. *outs, batch_img_metas=img_metas, cfg=cfg, rescale=True)