test_retina_sepBN_head.py

# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

import torch
from mmengine import Config
from mmengine.structures import InstanceData

from mmdet import *  # noqa
from mmdet.models.dense_heads import RetinaSepBNHead


class TestRetinaSepBNHead(TestCase):

    def test_init(self):
        """Test init RetinaSepBN head."""
        anchor_head = RetinaSepBNHead(num_classes=1, num_ins=1, in_channels=1)
        anchor_head.init_weights()
        self.assertTrue(anchor_head.cls_convs)
        self.assertTrue(anchor_head.reg_convs)
        self.assertTrue(anchor_head.retina_cls)
        self.assertTrue(anchor_head.retina_reg)

    def test_retina_sepbn_head_loss(self):
        """Tests RetinaSepBN head loss when truth is empty and non-empty."""
        s = 256
        img_metas = [{
            'img_shape': (s, s, 3),
            'pad_shape': (s, s, 3),
            'scale_factor': 1,
        }]
        cfg = Config(
            dict(
                assigner=dict(
                    type='MaxIoUAssigner',
                    pos_iou_thr=0.5,
                    neg_iou_thr=0.4,
                    min_pos_iou=0,
                    ignore_iof_thr=-1),
                sampler=dict(type='PseudoSampler'
                             ),  # Focal loss should use PseudoSampler
                allowed_border=-1,
                pos_weight=-1,
                debug=False))
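        # MaxIoUAssigner matches anchors to ground truth by IoU overlap, and
        # PseudoSampler keeps every assigned anchor without subsampling, the
        # usual choice when focal loss already handles the
        # foreground/background imbalance.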
        anchor_head = RetinaSepBNHead(
            num_classes=4, num_ins=5, in_channels=1, train_cfg=cfg)
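        # RetinaSepBNHead shares its conv weights across FPN levels but keeps
        # separate BN statistics per level, so num_ins should match the
        # number of input feature maps (five here).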

        # The anchor head expects multiple levels of features per image
        feats = []
        for i in range(len(anchor_head.prior_generator.strides)):
            feats.append(
                torch.rand(1, 1, s // (2**(i + 2)), s // (2**(i + 2))))
        cls_scores, bbox_preds = anchor_head.forward(tuple(feats))
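        # With s = 256 this builds one single-channel feature map per
        # prior-generator level (spatial sizes 64, 32, 16, 8, 4 for the
        # default five levels); forward() returns one cls_score and one
        # bbox_pred per level.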

        # Test that empty ground truth encourages the network to
        # predict background
        gt_instances = InstanceData()
        gt_instances.bboxes = torch.empty((0, 4))
        gt_instances.labels = torch.LongTensor([])

        empty_gt_losses = anchor_head.loss_by_feat(cls_scores, bbox_preds,
                                                   [gt_instances], img_metas)
        # When there is no truth, the cls loss should be nonzero but
        # there should be no box loss.
        empty_cls_loss = sum(empty_gt_losses['loss_cls'])
        empty_box_loss = sum(empty_gt_losses['loss_bbox'])
        self.assertGreater(empty_cls_loss.item(), 0,
                           'cls loss should be non-zero')
        self.assertEqual(
            empty_box_loss.item(), 0,
            'there should be no box loss when there are no true boxes')

        # When truth is non-empty then both cls and box loss
        # should be nonzero for random inputs
        gt_instances = InstanceData()
        gt_instances.bboxes = torch.Tensor(
            [[23.6667, 23.8757, 238.6326, 151.8874]])
        gt_instances.labels = torch.LongTensor([2])

        one_gt_losses = anchor_head.loss_by_feat(cls_scores, bbox_preds,
                                                 [gt_instances], img_metas)
        onegt_cls_loss = sum(one_gt_losses['loss_cls'])
        onegt_box_loss = sum(one_gt_losses['loss_bbox'])
        self.assertGreater(onegt_cls_loss.item(), 0,
                           'cls loss should be non-zero')
        self.assertGreater(onegt_box_loss.item(), 0,
                           'box loss should be non-zero')
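

# A minimal entry point for running this file directly; the test is normally
# collected by pytest in the mmdet test suite, so this guard is only a
# convenience and not part of the upstream file.
if __name__ == '__main__':
    import unittest
    unittest.main()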