# test_iou2d_calculator.py (5.2 KB)
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import numpy as np
  3. import pytest
  4. import torch
  5. from mmdet.evaluation import bbox_overlaps as recall_overlaps
  6. from mmdet.models.task_modules import BboxOverlaps2D
  7. from mmdet.structures.bbox import bbox_overlaps
  8. def test_bbox_overlaps_2d(eps=1e-7):
  9. def _construct_bbox(num_bbox=None):
  10. img_h = int(np.random.randint(3, 1000))
  11. img_w = int(np.random.randint(3, 1000))
  12. if num_bbox is None:
  13. num_bbox = np.random.randint(1, 10)
  14. x1y1 = torch.rand((num_bbox, 2))
  15. x2y2 = torch.max(torch.rand((num_bbox, 2)), x1y1)
  16. bboxes = torch.cat((x1y1, x2y2), -1)
  17. bboxes[:, 0::2] *= img_w
  18. bboxes[:, 1::2] *= img_h
  19. return bboxes, num_bbox
  20. # is_aligned is True, bboxes.size(-1) == 5 (include score)
  21. self = BboxOverlaps2D()
  22. bboxes1, num_bbox = _construct_bbox()
  23. bboxes2, _ = _construct_bbox(num_bbox)
  24. bboxes1 = torch.cat((bboxes1, torch.rand((num_bbox, 1))), 1)
  25. bboxes2 = torch.cat((bboxes2, torch.rand((num_bbox, 1))), 1)
  26. gious = self(bboxes1, bboxes2, 'giou', True)
  27. assert gious.size() == (num_bbox, ), gious.size()
  28. assert torch.all(gious >= -1) and torch.all(gious <= 1)
  29. # is_aligned is True, bboxes1.size(-2) == 0
  30. bboxes1 = torch.empty((0, 4))
  31. bboxes2 = torch.empty((0, 4))
  32. gious = self(bboxes1, bboxes2, 'giou', True)
  33. assert gious.size() == (0, ), gious.size()
  34. assert torch.all(gious == torch.empty((0, )))
  35. assert torch.all(gious >= -1) and torch.all(gious <= 1)
  36. # is_aligned is True, and bboxes.ndims > 2
  37. bboxes1, num_bbox = _construct_bbox()
  38. bboxes2, _ = _construct_bbox(num_bbox)
  39. bboxes1 = bboxes1.unsqueeze(0).repeat(2, 1, 1)
  40. # test assertion when batch dim is not the same
  41. with pytest.raises(AssertionError):
  42. self(bboxes1, bboxes2.unsqueeze(0).repeat(3, 1, 1), 'giou', True)
  43. bboxes2 = bboxes2.unsqueeze(0).repeat(2, 1, 1)
  44. gious = self(bboxes1, bboxes2, 'giou', True)
  45. assert torch.all(gious >= -1) and torch.all(gious <= 1)
  46. assert gious.size() == (2, num_bbox)
  47. bboxes1 = bboxes1.unsqueeze(0).repeat(2, 1, 1, 1)
  48. bboxes2 = bboxes2.unsqueeze(0).repeat(2, 1, 1, 1)
  49. gious = self(bboxes1, bboxes2, 'giou', True)
  50. assert torch.all(gious >= -1) and torch.all(gious <= 1)
  51. assert gious.size() == (2, 2, num_bbox)
  52. # is_aligned is False
  53. bboxes1, num_bbox1 = _construct_bbox()
  54. bboxes2, num_bbox2 = _construct_bbox()
  55. gious = self(bboxes1, bboxes2, 'giou')
  56. assert torch.all(gious >= -1) and torch.all(gious <= 1)
  57. assert gious.size() == (num_bbox1, num_bbox2)
  58. # is_aligned is False, and bboxes.ndims > 2
  59. bboxes1 = bboxes1.unsqueeze(0).repeat(2, 1, 1)
  60. bboxes2 = bboxes2.unsqueeze(0).repeat(2, 1, 1)
  61. gious = self(bboxes1, bboxes2, 'giou')
  62. assert torch.all(gious >= -1) and torch.all(gious <= 1)
  63. assert gious.size() == (2, num_bbox1, num_bbox2)
  64. bboxes1 = bboxes1.unsqueeze(0)
  65. bboxes2 = bboxes2.unsqueeze(0)
  66. gious = self(bboxes1, bboxes2, 'giou')
  67. assert torch.all(gious >= -1) and torch.all(gious <= 1)
  68. assert gious.size() == (1, 2, num_bbox1, num_bbox2)
  69. # is_aligned is False, bboxes1.size(-2) == 0
  70. gious = self(torch.empty(1, 2, 0, 4), bboxes2, 'giou')
  71. assert torch.all(gious == torch.empty(1, 2, 0, bboxes2.size(-2)))
  72. assert torch.all(gious >= -1) and torch.all(gious <= 1)
  73. # test allclose between bbox_overlaps and the original official
  74. # implementation.
  75. bboxes1 = torch.FloatTensor([
  76. [0, 0, 10, 10],
  77. [10, 10, 20, 20],
  78. [32, 32, 38, 42],
  79. ])
  80. bboxes2 = torch.FloatTensor([
  81. [0, 0, 10, 20],
  82. [0, 10, 10, 19],
  83. [10, 10, 20, 20],
  84. ])
  85. gious = bbox_overlaps(bboxes1, bboxes2, 'giou', is_aligned=True, eps=eps)
  86. gious = gious.numpy().round(4)
  87. # the gt is got with four decimal precision.
  88. expected_gious = np.array([0.5000, -0.0500, -0.8214])
  89. assert np.allclose(gious, expected_gious, rtol=0, atol=eps)
  90. # test mode 'iof'
  91. ious = bbox_overlaps(bboxes1, bboxes2, 'iof', is_aligned=True, eps=eps)
  92. assert torch.all(ious >= -1) and torch.all(ious <= 1)
  93. assert ious.size() == (bboxes1.size(0), )
  94. ious = bbox_overlaps(bboxes1, bboxes2, 'iof', eps=eps)
  95. assert torch.all(ious >= -1) and torch.all(ious <= 1)
  96. assert ious.size() == (bboxes1.size(0), bboxes2.size(0))
  97. def test_voc_recall_overlaps():
  98. def _construct_bbox(num_bbox=None):
  99. img_h = int(np.random.randint(3, 1000))
  100. img_w = int(np.random.randint(3, 1000))
  101. if num_bbox is None:
  102. num_bbox = np.random.randint(1, 10)
  103. x1y1 = torch.rand((num_bbox, 2))
  104. x2y2 = torch.max(torch.rand((num_bbox, 2)), x1y1)
  105. bboxes = torch.cat((x1y1, x2y2), -1)
  106. bboxes[:, 0::2] *= img_w
  107. bboxes[:, 1::2] *= img_h
  108. return bboxes.numpy(), num_bbox
  109. bboxes1, num_bbox = _construct_bbox()
  110. bboxes2, _ = _construct_bbox(num_bbox)
  111. ious = recall_overlaps(
  112. bboxes1, bboxes2, 'iou', use_legacy_coordinate=False)
  113. assert ious.shape == (num_bbox, num_bbox)
  114. assert np.all(ious >= -1) and np.all(ious <= 1)
  115. ious = recall_overlaps(bboxes1, bboxes2, 'iou', use_legacy_coordinate=True)
  116. assert ious.shape == (num_bbox, num_bbox)
  117. assert np.all(ious >= -1) and np.all(ious <= 1)