# iou2d_calculator.py
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from mmdet.registry import TASK_UTILS
from mmdet.structures.bbox import bbox_overlaps, get_box_tensor
  5. def cast_tensor_type(x, scale=1., dtype=None):
  6. if dtype == 'fp16':
  7. # scale is for preventing overflows
  8. x = (x / scale).half()
  9. return x
  10. @TASK_UTILS.register_module()
  11. class BboxOverlaps2D:
  12. """2D Overlaps (e.g. IoUs, GIoUs) Calculator."""
  13. def __init__(self, scale=1., dtype=None):
  14. self.scale = scale
  15. self.dtype = dtype
  16. def __call__(self, bboxes1, bboxes2, mode='iou', is_aligned=False):
  17. """Calculate IoU between 2D bboxes.
  18. Args:
  19. bboxes1 (Tensor or :obj:`BaseBoxes`): bboxes have shape (m, 4)
  20. in <x1, y1, x2, y2> format, or shape (m, 5) in <x1, y1, x2,
  21. y2, score> format.
  22. bboxes2 (Tensor or :obj:`BaseBoxes`): bboxes have shape (m, 4)
  23. in <x1, y1, x2, y2> format, shape (m, 5) in <x1, y1, x2, y2,
  24. score> format, or be empty. If ``is_aligned `` is ``True``,
  25. then m and n must be equal.
  26. mode (str): "iou" (intersection over union), "iof" (intersection
  27. over foreground), or "giou" (generalized intersection over
  28. union).
  29. is_aligned (bool, optional): If True, then m and n must be equal.
  30. Default False.
  31. Returns:
  32. Tensor: shape (m, n) if ``is_aligned `` is False else shape (m,)
  33. """
  34. bboxes1 = get_box_tensor(bboxes1)
  35. bboxes2 = get_box_tensor(bboxes2)
  36. assert bboxes1.size(-1) in [0, 4, 5]
  37. assert bboxes2.size(-1) in [0, 4, 5]
  38. if bboxes2.size(-1) == 5:
  39. bboxes2 = bboxes2[..., :4]
  40. if bboxes1.size(-1) == 5:
  41. bboxes1 = bboxes1[..., :4]
  42. if self.dtype == 'fp16':
  43. # change tensor type to save cpu and cuda memory and keep speed
  44. bboxes1 = cast_tensor_type(bboxes1, self.scale, self.dtype)
  45. bboxes2 = cast_tensor_type(bboxes2, self.scale, self.dtype)
  46. overlaps = bbox_overlaps(bboxes1, bboxes2, mode, is_aligned)
  47. if not overlaps.is_cuda and overlaps.dtype == torch.float16:
  48. # resume cpu float32
  49. overlaps = overlaps.float()
  50. return overlaps
  51. return bbox_overlaps(bboxes1, bboxes2, mode, is_aligned)
  52. def __repr__(self):
  53. """str: a string describing the module"""
  54. repr_str = self.__class__.__name__ + f'(' \
  55. f'scale={self.scale}, dtype={self.dtype})'
  56. return repr_str