test_deformable_detr.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from unittest import TestCase
  3. import torch
  4. from mmengine.structures import InstanceData
  5. from mmdet.registry import MODELS
  6. from mmdet.structures import DetDataSample
  7. from mmdet.testing import get_detector_cfg
  8. from mmdet.utils import register_all_modules
  9. class TestDeformableDETR(TestCase):
  10. def setUp(self):
  11. register_all_modules()
  12. def test_deformable_detr_head_loss(self):
  13. """Tests transformer head loss when truth is empty and non-empty."""
  14. s = 256
  15. metainfo = {
  16. 'img_shape': (s, s),
  17. 'scale_factor': (1, 1),
  18. 'pad_shape': (s, s),
  19. 'batch_input_shape': (s, s)
  20. }
  21. img_metas = DetDataSample()
  22. img_metas.set_metainfo(metainfo)
  23. batch_data_samples = []
  24. batch_data_samples.append(img_metas)
  25. configs = [
  26. get_detector_cfg(
  27. 'deformable_detr/deformable-detr_r50_16xb2-50e_coco.py'),
  28. get_detector_cfg(
  29. 'deformable_detr/deformable-detr-refine_r50_16xb2-50e_coco.py' # noqa
  30. ),
  31. get_detector_cfg(
  32. 'deformable_detr/deformable-detr-refine-twostage_r50_16xb2-50e_coco.py' # noqa
  33. )
  34. ]
  35. for config in configs:
  36. model = MODELS.build(config)
  37. model.init_weights()
  38. random_image = torch.rand(1, 3, s, s)
  39. # Test that empty ground truth encourages the network to
  40. # predict background
  41. gt_instances = InstanceData()
  42. gt_instances.bboxes = torch.empty((0, 4))
  43. gt_instances.labels = torch.LongTensor([])
  44. img_metas.gt_instances = gt_instances
  45. batch_data_samples1 = []
  46. batch_data_samples1.append(img_metas)
  47. empty_gt_losses = model.loss(
  48. random_image, batch_data_samples=batch_data_samples1)
  49. # When there is no truth, the cls loss should be nonzero but there
  50. # should be no box loss.
  51. for key, loss in empty_gt_losses.items():
  52. if 'cls' in key:
  53. self.assertGreater(loss.item(), 0,
  54. 'cls loss should be non-zero')
  55. elif 'bbox' in key:
  56. self.assertEqual(
  57. loss.item(), 0,
  58. 'there should be no box loss when no ground true boxes'
  59. )
  60. elif 'iou' in key:
  61. self.assertEqual(
  62. loss.item(), 0,
  63. 'there should be no iou loss when no ground true boxes'
  64. )
  65. # When truth is non-empty then both cls and box loss should
  66. # be nonzero for random inputs
  67. gt_instances = InstanceData()
  68. gt_instances.bboxes = torch.Tensor(
  69. [[23.6667, 23.8757, 238.6326, 151.8874]])
  70. gt_instances.labels = torch.LongTensor([2])
  71. img_metas.gt_instances = gt_instances
  72. batch_data_samples2 = []
  73. batch_data_samples2.append(img_metas)
  74. one_gt_losses = model.loss(
  75. random_image, batch_data_samples=batch_data_samples2)
  76. for loss in one_gt_losses.values():
  77. self.assertGreater(
  78. loss.item(), 0,
  79. 'cls loss, or box loss, or iou loss should be non-zero')
  80. model.eval()
  81. # test _forward
  82. model._forward(
  83. random_image, batch_data_samples=batch_data_samples2)
  84. # test only predict
  85. model.predict(
  86. random_image,
  87. batch_data_samples=batch_data_samples2,
  88. rescale=True)