test_pisa_retinanet_head.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from math import ceil
  3. from unittest import TestCase
  4. import torch
  5. from mmengine import Config
  6. from mmengine.structures import InstanceData
  7. from mmdet import * # noqa
  8. from mmdet.models.dense_heads import PISARetinaHead
  class TestPISARetinaHead(TestCase):
      """Unit tests for :class:`PISARetinaHead`."""

      def test_pisa_reitnanet_head_loss(self):
          """Tests pisa retinanet head loss when truth is empty and non-empty.

          With no ground-truth boxes, only the classification loss is expected
          to be non-zero (box and CARL losses must be exactly zero). With one
          ground-truth box, the classification, box and CARL losses are all
          expected to be non-zero for random inputs.
          """
          s = 300
          img_metas = [{
              'img_shape': (s, s),
              'pad_shape': (s, s),
              'scale_factor': 1,
          }]
          cfg = Config(
              dict(
                  assigner=dict(
                      type='MaxIoUAssigner',
                      pos_iou_thr=0.5,
                      neg_iou_thr=0.4,
                      min_pos_iou=0,
                      ignore_iof_thr=-1),
                  # PISA-specific settings; `carl` presumably enables the
                  # 'loss_carl' term asserted below -- confirm against
                  # PISARetinaHead.loss_by_feat.
                  isr=dict(k=2., bias=0.),
                  carl=dict(k=1., bias=0.2),
                  sampler=dict(type='PseudoSampler'),
                  allowed_border=-1,
                  pos_weight=-1,
                  debug=False))
          pisa_retinanet_head = PISARetinaHead(
              num_classes=4,
              in_channels=1,
              stacked_convs=1,
              feat_channels=256,
              anchor_generator=dict(
                  type='AnchorGenerator',
                  octave_base_scale=4,
                  scales_per_octave=3,
                  ratios=[0.5, 1.0, 2.0],
                  strides=[8, 16, 32, 64, 128]),
              bbox_coder=dict(
                  type='DeltaXYWHBBoxCoder',
                  target_means=[.0, .0, .0, .0],
                  target_stds=[1.0, 1.0, 1.0, 1.0]),
              loss_cls=dict(
                  type='FocalLoss',
                  use_sigmoid=True,
                  gamma=2.0,
                  alpha=0.25,
                  loss_weight=1.0),
              loss_bbox=dict(type='SmoothL1Loss', beta=0.11, loss_weight=1.0),
              train_cfg=cfg)

          # PISA retina head expects multiple levels of features per image:
          # one random single-channel (in_channels=1) map per prior-generator
          # stride, sized ceil(s / stride). Build a tuple rather than a
          # generator expression, which could only be iterated once.
          feats = tuple(
              torch.rand(1, 1, ceil(s / stride[0]), ceil(s / stride[0]))
              for stride in pisa_retinanet_head.prior_generator.strides)
          cls_scores, bbox_preds = pisa_retinanet_head.forward(feats)

          # Test that empty ground truth encourages the network to
          # predict background
          gt_instances = InstanceData()
          gt_instances.bboxes = torch.empty((0, 4))
          gt_instances.labels = torch.LongTensor([])
          empty_gt_losses = pisa_retinanet_head.loss_by_feat(
              cls_scores, bbox_preds, [gt_instances], img_metas)
          # When there is no truth, the cls loss should be non-zero (every
          # anchor is a background target), while the box and carl losses
          # should be zero because there are no positive samples.
          empty_cls_loss = empty_gt_losses['loss_cls']
          empty_box_loss = empty_gt_losses['loss_bbox']
          empty_carl_loss = empty_gt_losses['loss_carl']
          self.assertGreater(empty_cls_loss.item(), 0,
                             'cls loss should be non-zero')
          self.assertEqual(
              empty_box_loss.item(), 0,
              'there should be no box loss when there are no true boxes')
          self.assertEqual(
              empty_carl_loss.item(), 0,
              'there should be no carl loss when there are no true boxes')

          # When truth is non-empty then both cls and box loss
          # should be nonzero for random inputs
          gt_instances = InstanceData()
          gt_instances.bboxes = torch.Tensor(
              [[23.6667, 23.8757, 238.6326, 151.8874]])
          gt_instances.labels = torch.LongTensor([2])
          one_gt_losses = pisa_retinanet_head.loss_by_feat(
              cls_scores, bbox_preds, [gt_instances], img_metas)
          onegt_cls_loss = one_gt_losses['loss_cls']
          onegt_box_loss = one_gt_losses['loss_bbox']
          onegt_carl_loss = one_gt_losses['loss_carl']
          self.assertGreater(onegt_cls_loss.item(), 0,
                             'cls loss should be non-zero')
          self.assertGreater(onegt_box_loss.item(), 0,
                             'box loss should be non-zero')
          self.assertGreater(onegt_carl_loss.item(), 0,
                             'carl loss should be non-zero')