test_ga_retina_head.py 3.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from unittest import TestCase
  3. import torch
  4. from mmengine.config import ConfigDict
  5. from mmdet.models.dense_heads import GARetinaHead
  6. ga_retina_head_config = ConfigDict(
  7. dict(
  8. num_classes=4,
  9. in_channels=4,
  10. feat_channels=4,
  11. stacked_convs=1,
  12. approx_anchor_generator=dict(
  13. type='AnchorGenerator',
  14. octave_base_scale=4,
  15. scales_per_octave=3,
  16. ratios=[0.5, 1.0, 2.0],
  17. strides=[8, 16, 32, 64, 128]),
  18. square_anchor_generator=dict(
  19. type='AnchorGenerator',
  20. ratios=[1.0],
  21. scales=[4],
  22. strides=[8, 16, 32, 64, 128]),
  23. anchor_coder=dict(
  24. type='DeltaXYWHBBoxCoder',
  25. target_means=[.0, .0, .0, .0],
  26. target_stds=[1.0, 1.0, 1.0, 1.0]),
  27. bbox_coder=dict(
  28. type='DeltaXYWHBBoxCoder',
  29. target_means=[.0, .0, .0, .0],
  30. target_stds=[1.0, 1.0, 1.0, 1.0]),
  31. loc_filter_thr=0.01,
  32. loss_loc=dict(
  33. type='FocalLoss',
  34. use_sigmoid=True,
  35. gamma=2.0,
  36. alpha=0.25,
  37. loss_weight=1.0),
  38. loss_shape=dict(type='BoundedIoULoss', beta=0.2, loss_weight=1.0),
  39. loss_cls=dict(
  40. type='FocalLoss',
  41. use_sigmoid=True,
  42. gamma=2.0,
  43. alpha=0.25,
  44. loss_weight=1.0),
  45. loss_bbox=dict(type='SmoothL1Loss', beta=0.04, loss_weight=1.0),
  46. train_cfg=dict(
  47. ga_assigner=dict(
  48. type='ApproxMaxIoUAssigner',
  49. pos_iou_thr=0.5,
  50. neg_iou_thr=0.4,
  51. min_pos_iou=0.4,
  52. ignore_iof_thr=-1),
  53. ga_sampler=dict(
  54. type='RandomSampler',
  55. num=256,
  56. pos_fraction=0.5,
  57. neg_pos_ub=-1,
  58. add_gt_as_proposals=False),
  59. assigner=dict(
  60. type='MaxIoUAssigner',
  61. pos_iou_thr=0.5,
  62. neg_iou_thr=0.5,
  63. min_pos_iou=0.0,
  64. ignore_iof_thr=-1),
  65. allowed_border=-1,
  66. pos_weight=-1,
  67. center_ratio=0.2,
  68. ignore_ratio=0.5,
  69. debug=False),
  70. test_cfg=dict(
  71. nms_pre=1000,
  72. min_bbox_size=0,
  73. score_thr=0.05,
  74. nms=dict(type='nms', iou_threshold=0.5),
  75. max_per_img=100)))
  76. class TestGARetinaHead(TestCase):
  77. def test_ga_retina_head_init_and_forward(self):
  78. """The GARetinaHead inherit loss and prediction function from
  79. GuidedAchorHead.
  80. Here, we only test GARetinaHet initialization and forward.
  81. """
  82. # Test initializaion
  83. ga_retina_head = GARetinaHead(**ga_retina_head_config)
  84. # Test forward
  85. s = 256
  86. feats = (
  87. torch.rand(1, 4, s // stride[1], s // stride[0])
  88. for stride in ga_retina_head.square_anchor_generator.strides)
  89. ga_retina_head(feats)