test_anchor_head.py 3.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from unittest import TestCase
  3. import torch
  4. from mmengine import Config
  5. from mmengine.structures import InstanceData
  6. from mmdet import * # noqa
  7. from mmdet.models.dense_heads import AnchorHead
  8. class TestAnchorHead(TestCase):
  9. def test_anchor_head_loss(self):
  10. """Tests anchor head loss when truth is empty and non-empty."""
  11. s = 256
  12. img_metas = [{
  13. 'img_shape': (s, s, 3),
  14. 'pad_shape': (s, s, 3),
  15. 'scale_factor': 1,
  16. }]
  17. cfg = Config(
  18. dict(
  19. assigner=dict(
  20. type='MaxIoUAssigner',
  21. pos_iou_thr=0.7,
  22. neg_iou_thr=0.3,
  23. min_pos_iou=0.3,
  24. match_low_quality=True,
  25. ignore_iof_thr=-1),
  26. sampler=dict(
  27. type='RandomSampler',
  28. num=256,
  29. pos_fraction=0.5,
  30. neg_pos_ub=-1,
  31. add_gt_as_proposals=False),
  32. allowed_border=0,
  33. pos_weight=-1,
  34. debug=False))
  35. anchor_head = AnchorHead(num_classes=4, in_channels=1, train_cfg=cfg)
  36. # Anchor head expects a multiple levels of features per image
  37. feats = (
  38. torch.rand(1, 1, s // (2**(i + 2)), s // (2**(i + 2)))
  39. for i in range(len(anchor_head.prior_generator.strides)))
  40. cls_scores, bbox_preds = anchor_head.forward(feats)
  41. # Test that empty ground truth encourages the network to
  42. # predict background
  43. gt_instances = InstanceData()
  44. gt_instances.bboxes = torch.empty((0, 4))
  45. gt_instances.labels = torch.LongTensor([])
  46. empty_gt_losses = anchor_head.loss_by_feat(cls_scores, bbox_preds,
  47. [gt_instances], img_metas)
  48. # When there is no truth, the cls loss should be nonzero but
  49. # there should be no box loss.
  50. empty_cls_loss = sum(empty_gt_losses['loss_cls'])
  51. empty_box_loss = sum(empty_gt_losses['loss_bbox'])
  52. assert empty_cls_loss.item() > 0, 'cls loss should be non-zero'
  53. assert empty_box_loss.item() == 0, (
  54. 'there should be no box loss when there are no true boxes')
  55. # When truth is non-empty then both cls and box loss
  56. # should be nonzero for random inputs
  57. gt_instances = InstanceData()
  58. gt_instances.bboxes = torch.Tensor(
  59. [[23.6667, 23.8757, 238.6326, 151.8874]])
  60. gt_instances.labels = torch.LongTensor([2])
  61. one_gt_losses = anchor_head.loss_by_feat(cls_scores, bbox_preds,
  62. [gt_instances], img_metas)
  63. onegt_cls_loss = sum(one_gt_losses['loss_cls'])
  64. onegt_box_loss = sum(one_gt_losses['loss_bbox'])
  65. assert onegt_cls_loss.item() > 0, 'cls loss should be non-zero'
  66. assert onegt_box_loss.item() > 0, 'box loss should be non-zero'