test_centernet_update_head.py 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from unittest import TestCase
  3. import torch
  4. from mmengine.structures import InstanceData
  5. from mmdet.models.dense_heads import CenterNetUpdateHead
  6. class TestCenterNetUpdateHead(TestCase):
  7. def test_centernet_update_head_loss(self):
  8. """Tests fcos head loss when truth is empty and non-empty."""
  9. s = 256
  10. img_metas = [{
  11. 'img_shape': (s, s, 3),
  12. 'pad_shape': (s, s, 3),
  13. 'scale_factor': 1,
  14. }]
  15. centernet_head = CenterNetUpdateHead(
  16. num_classes=4,
  17. in_channels=1,
  18. feat_channels=1,
  19. stacked_convs=1,
  20. norm_cfg=None)
  21. # Fcos head expects a multiple levels of features per image
  22. feats = (
  23. torch.rand(1, 1, s // stride[1], s // stride[0])
  24. for stride in centernet_head.prior_generator.strides)
  25. cls_scores, bbox_preds = centernet_head.forward(feats)
  26. # Test that empty ground truth encourages the network to
  27. # predict background
  28. gt_instances = InstanceData()
  29. gt_instances.bboxes = torch.empty((0, 4))
  30. gt_instances.labels = torch.LongTensor([])
  31. empty_gt_losses = centernet_head.loss_by_feat(cls_scores, bbox_preds,
  32. [gt_instances],
  33. img_metas)
  34. # When there is no truth, the cls loss should be nonzero but
  35. # box loss and centerness loss should be zero
  36. empty_cls_loss = empty_gt_losses['loss_cls'].item()
  37. empty_box_loss = empty_gt_losses['loss_bbox'].item()
  38. self.assertGreater(empty_cls_loss, 0, 'cls loss should be non-zero')
  39. self.assertEqual(
  40. empty_box_loss, 0,
  41. 'there should be no box loss when there are no true boxes')
  42. # When truth is non-empty then all cls, box loss and centerness loss
  43. # should be nonzero for random inputs
  44. gt_instances = InstanceData()
  45. gt_instances.bboxes = torch.Tensor(
  46. [[23.6667, 23.8757, 238.6326, 151.8874]])
  47. gt_instances.labels = torch.LongTensor([2])
  48. one_gt_losses = centernet_head.loss_by_feat(cls_scores, bbox_preds,
  49. [gt_instances], img_metas)
  50. onegt_cls_loss = one_gt_losses['loss_cls'].item()
  51. onegt_box_loss = one_gt_losses['loss_bbox'].item()
  52. self.assertGreater(onegt_cls_loss, 0, 'cls loss should be non-zero')
  53. self.assertGreater(onegt_box_loss, 0, 'box loss should be non-zero')