legacy_delta_xywh_bbox_coder.py 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import Optional, Sequence, Union
  3. import numpy as np
  4. import torch
  5. from torch import Tensor
  6. from mmdet.registry import TASK_UTILS
  7. from mmdet.structures.bbox import BaseBoxes, HorizontalBoxes, get_box_tensor
  8. from .base_bbox_coder import BaseBBoxCoder
  9. @TASK_UTILS.register_module()
  10. class LegacyDeltaXYWHBBoxCoder(BaseBBoxCoder):
  11. """Legacy Delta XYWH BBox coder used in MMDet V1.x.
  12. Following the practice in R-CNN [1]_, this coder encodes bbox (x1, y1, x2,
  13. y2) into delta (dx, dy, dw, dh) and decodes delta (dx, dy, dw, dh)
  14. back to original bbox (x1, y1, x2, y2).
  15. Note:
  16. The main difference between :class`LegacyDeltaXYWHBBoxCoder` and
  17. :class:`DeltaXYWHBBoxCoder` is whether ``+ 1`` is used during width and
  18. height calculation. We suggest to only use this coder when testing with
  19. MMDet V1.x models.
  20. References:
  21. .. [1] https://arxiv.org/abs/1311.2524
  22. Args:
  23. target_means (Sequence[float]): denormalizing means of target for
  24. delta coordinates
  25. target_stds (Sequence[float]): denormalizing standard deviation of
  26. target for delta coordinates
  27. """
  28. def __init__(self,
  29. target_means: Sequence[float] = (0., 0., 0., 0.),
  30. target_stds: Sequence[float] = (1., 1., 1., 1.),
  31. **kwargs) -> None:
  32. super().__init__(**kwargs)
  33. self.means = target_means
  34. self.stds = target_stds
  35. def encode(self, bboxes: Union[Tensor, BaseBoxes],
  36. gt_bboxes: Union[Tensor, BaseBoxes]) -> Tensor:
  37. """Get box regression transformation deltas that can be used to
  38. transform the ``bboxes`` into the ``gt_bboxes``.
  39. Args:
  40. bboxes (torch.Tensor or :obj:`BaseBoxes`): source boxes,
  41. e.g., object proposals.
  42. gt_bboxes (torch.Tensor or :obj:`BaseBoxes`): target of the
  43. transformation, e.g., ground-truth boxes.
  44. Returns:
  45. torch.Tensor: Box transformation deltas
  46. """
  47. bboxes = get_box_tensor(bboxes)
  48. gt_bboxes = get_box_tensor(gt_bboxes)
  49. assert bboxes.size(0) == gt_bboxes.size(0)
  50. assert bboxes.size(-1) == gt_bboxes.size(-1) == 4
  51. encoded_bboxes = legacy_bbox2delta(bboxes, gt_bboxes, self.means,
  52. self.stds)
  53. return encoded_bboxes
  54. def decode(
  55. self,
  56. bboxes: Union[Tensor, BaseBoxes],
  57. pred_bboxes: Tensor,
  58. max_shape: Optional[Union[Sequence[int], Tensor,
  59. Sequence[Sequence[int]]]] = None,
  60. wh_ratio_clip: Optional[float] = 16 / 1000
  61. ) -> Union[Tensor, BaseBoxes]:
  62. """Apply transformation `pred_bboxes` to `boxes`.
  63. Args:
  64. boxes (torch.Tensor or :obj:`BaseBoxes`): Basic boxes.
  65. pred_bboxes (torch.Tensor): Encoded boxes with shape
  66. max_shape (tuple[int], optional): Maximum shape of boxes.
  67. Defaults to None.
  68. wh_ratio_clip (float, optional): The allowed ratio between
  69. width and height.
  70. Returns:
  71. Union[torch.Tensor, :obj:`BaseBoxes`]: Decoded boxes.
  72. """
  73. bboxes = get_box_tensor(bboxes)
  74. assert pred_bboxes.size(0) == bboxes.size(0)
  75. decoded_bboxes = legacy_delta2bbox(bboxes, pred_bboxes, self.means,
  76. self.stds, max_shape, wh_ratio_clip)
  77. if self.use_box_type:
  78. assert decoded_bboxes.size(-1) == 4, \
  79. ('Cannot warp decoded boxes with box type when decoded boxes'
  80. 'have shape of (N, num_classes * 4)')
  81. decoded_bboxes = HorizontalBoxes(decoded_bboxes)
  82. return decoded_bboxes
  83. def legacy_bbox2delta(
  84. proposals: Tensor,
  85. gt: Tensor,
  86. means: Sequence[float] = (0., 0., 0., 0.),
  87. stds: Sequence[float] = (1., 1., 1., 1.)
  88. ) -> Tensor:
  89. """Compute deltas of proposals w.r.t. gt in the MMDet V1.x manner.
  90. We usually compute the deltas of x, y, w, h of proposals w.r.t ground
  91. truth bboxes to get regression target.
  92. This is the inverse function of `delta2bbox()`
  93. Args:
  94. proposals (Tensor): Boxes to be transformed, shape (N, ..., 4)
  95. gt (Tensor): Gt bboxes to be used as base, shape (N, ..., 4)
  96. means (Sequence[float]): Denormalizing means for delta coordinates
  97. stds (Sequence[float]): Denormalizing standard deviation for delta
  98. coordinates
  99. Returns:
  100. Tensor: deltas with shape (N, 4), where columns represent dx, dy,
  101. dw, dh.
  102. """
  103. assert proposals.size() == gt.size()
  104. proposals = proposals.float()
  105. gt = gt.float()
  106. px = (proposals[..., 0] + proposals[..., 2]) * 0.5
  107. py = (proposals[..., 1] + proposals[..., 3]) * 0.5
  108. pw = proposals[..., 2] - proposals[..., 0] + 1.0
  109. ph = proposals[..., 3] - proposals[..., 1] + 1.0
  110. gx = (gt[..., 0] + gt[..., 2]) * 0.5
  111. gy = (gt[..., 1] + gt[..., 3]) * 0.5
  112. gw = gt[..., 2] - gt[..., 0] + 1.0
  113. gh = gt[..., 3] - gt[..., 1] + 1.0
  114. dx = (gx - px) / pw
  115. dy = (gy - py) / ph
  116. dw = torch.log(gw / pw)
  117. dh = torch.log(gh / ph)
  118. deltas = torch.stack([dx, dy, dw, dh], dim=-1)
  119. means = deltas.new_tensor(means).unsqueeze(0)
  120. stds = deltas.new_tensor(stds).unsqueeze(0)
  121. deltas = deltas.sub_(means).div_(stds)
  122. return deltas
  123. def legacy_delta2bbox(rois: Tensor,
  124. deltas: Tensor,
  125. means: Sequence[float] = (0., 0., 0., 0.),
  126. stds: Sequence[float] = (1., 1., 1., 1.),
  127. max_shape: Optional[
  128. Union[Sequence[int], Tensor,
  129. Sequence[Sequence[int]]]] = None,
  130. wh_ratio_clip: float = 16 / 1000) -> Tensor:
  131. """Apply deltas to shift/scale base boxes in the MMDet V1.x manner.
  132. Typically the rois are anchor or proposed bounding boxes and the deltas are
  133. network outputs used to shift/scale those boxes.
  134. This is the inverse function of `bbox2delta()`
  135. Args:
  136. rois (Tensor): Boxes to be transformed. Has shape (N, 4)
  137. deltas (Tensor): Encoded offsets with respect to each roi.
  138. Has shape (N, 4 * num_classes). Note N = num_anchors * W * H when
  139. rois is a grid of anchors. Offset encoding follows [1]_.
  140. means (Sequence[float]): Denormalizing means for delta coordinates
  141. stds (Sequence[float]): Denormalizing standard deviation for delta
  142. coordinates
  143. max_shape (tuple[int, int]): Maximum bounds for boxes. specifies (H, W)
  144. wh_ratio_clip (float): Maximum aspect ratio for boxes.
  145. Returns:
  146. Tensor: Boxes with shape (N, 4), where columns represent
  147. tl_x, tl_y, br_x, br_y.
  148. References:
  149. .. [1] https://arxiv.org/abs/1311.2524
  150. Example:
  151. >>> rois = torch.Tensor([[ 0., 0., 1., 1.],
  152. >>> [ 0., 0., 1., 1.],
  153. >>> [ 0., 0., 1., 1.],
  154. >>> [ 5., 5., 5., 5.]])
  155. >>> deltas = torch.Tensor([[ 0., 0., 0., 0.],
  156. >>> [ 1., 1., 1., 1.],
  157. >>> [ 0., 0., 2., -1.],
  158. >>> [ 0.7, -1.9, -0.5, 0.3]])
  159. >>> legacy_delta2bbox(rois, deltas, max_shape=(32, 32))
  160. tensor([[0.0000, 0.0000, 1.5000, 1.5000],
  161. [0.0000, 0.0000, 5.2183, 5.2183],
  162. [0.0000, 0.1321, 7.8891, 0.8679],
  163. [5.3967, 2.4251, 6.0033, 3.7749]])
  164. """
  165. means = deltas.new_tensor(means).repeat(1, deltas.size(1) // 4)
  166. stds = deltas.new_tensor(stds).repeat(1, deltas.size(1) // 4)
  167. denorm_deltas = deltas * stds + means
  168. dx = denorm_deltas[:, 0::4]
  169. dy = denorm_deltas[:, 1::4]
  170. dw = denorm_deltas[:, 2::4]
  171. dh = denorm_deltas[:, 3::4]
  172. max_ratio = np.abs(np.log(wh_ratio_clip))
  173. dw = dw.clamp(min=-max_ratio, max=max_ratio)
  174. dh = dh.clamp(min=-max_ratio, max=max_ratio)
  175. # Compute center of each roi
  176. px = ((rois[:, 0] + rois[:, 2]) * 0.5).unsqueeze(1).expand_as(dx)
  177. py = ((rois[:, 1] + rois[:, 3]) * 0.5).unsqueeze(1).expand_as(dy)
  178. # Compute width/height of each roi
  179. pw = (rois[:, 2] - rois[:, 0] + 1.0).unsqueeze(1).expand_as(dw)
  180. ph = (rois[:, 3] - rois[:, 1] + 1.0).unsqueeze(1).expand_as(dh)
  181. # Use exp(network energy) to enlarge/shrink each roi
  182. gw = pw * dw.exp()
  183. gh = ph * dh.exp()
  184. # Use network energy to shift the center of each roi
  185. gx = px + pw * dx
  186. gy = py + ph * dy
  187. # Convert center-xy/width/height to top-left, bottom-right
  188. # The true legacy box coder should +- 0.5 here.
  189. # However, current implementation improves the performance when testing
  190. # the models trained in MMDetection 1.X (~0.5 bbox AP, 0.2 mask AP)
  191. x1 = gx - gw * 0.5
  192. y1 = gy - gh * 0.5
  193. x2 = gx + gw * 0.5
  194. y2 = gy + gh * 0.5
  195. if max_shape is not None:
  196. x1 = x1.clamp(min=0, max=max_shape[1] - 1)
  197. y1 = y1.clamp(min=0, max=max_shape[0] - 1)
  198. x2 = x2.clamp(min=0, max=max_shape[1] - 1)
  199. y2 = y2.clamp(min=0, max=max_shape[0] - 1)
  200. bboxes = torch.stack([x1, y1, x2, y2], dim=-1).view_as(deltas)
  201. return bboxes