# horizontal_boxes.py
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import Optional, Tuple, TypeVar, Union
  3. import cv2
  4. import numpy as np
  5. import torch
  6. from torch import BoolTensor, Tensor
  7. from mmdet.structures.mask.structures import BitmapMasks, PolygonMasks
  8. from .base_boxes import BaseBoxes
  9. from .bbox_overlaps import bbox_overlaps
  10. from .box_type import register_box
# Generic type variable available for annotations in this module.
T = TypeVar('T')
# A device given either by its string name or as a ``torch.device`` object.
DeviceType = Union[str, torch.device]
# Instance masks in either bitmap or polygon representation.
MaskType = Union[BitmapMasks, PolygonMasks]
  14. @register_box(name='hbox')
  15. class HorizontalBoxes(BaseBoxes):
  16. """The horizontal box class used in MMDetection by default.
  17. The ``box_dim`` of ``HorizontalBoxes`` is 4, which means the length of
  18. the last dimension of the data should be 4. Two modes of box data are
  19. supported in ``HorizontalBoxes``:
  20. - 'xyxy': Each row of data indicates (x1, y1, x2, y2), which are the
  21. coordinates of the left-top and right-bottom points.
  22. - 'cxcywh': Each row of data indicates (x, y, w, h), where (x, y) are the
  23. coordinates of the box centers and (w, h) are the width and height.
  24. ``HorizontalBoxes`` only restores 'xyxy' mode of data. If the the data is
  25. in 'cxcywh' mode, users need to input ``in_mode='cxcywh'`` and The code
  26. will convert the 'cxcywh' data to 'xyxy' automatically.
  27. Args:
  28. data (Tensor or np.ndarray or Sequence): The box data with shape of
  29. (..., 4).
  30. dtype (torch.dtype, Optional): data type of boxes. Defaults to None.
  31. device (str or torch.device, Optional): device of boxes.
  32. Default to None.
  33. clone (bool): Whether clone ``boxes`` or not. Defaults to True.
  34. mode (str, Optional): the mode of boxes. If it is 'cxcywh', the
  35. `data` will be converted to 'xyxy' mode. Defaults to None.
  36. """
  37. box_dim: int = 4
  38. def __init__(self,
  39. data: Union[Tensor, np.ndarray],
  40. dtype: torch.dtype = None,
  41. device: DeviceType = None,
  42. clone: bool = True,
  43. in_mode: Optional[str] = None) -> None:
  44. super().__init__(data=data, dtype=dtype, device=device, clone=clone)
  45. if isinstance(in_mode, str):
  46. if in_mode not in ('xyxy', 'cxcywh'):
  47. raise ValueError(f'Get invalid mode {in_mode}.')
  48. if in_mode == 'cxcywh':
  49. self.tensor = self.cxcywh_to_xyxy(self.tensor)
  50. @staticmethod
  51. def cxcywh_to_xyxy(boxes: Tensor) -> Tensor:
  52. """Convert box coordinates from (cx, cy, w, h) to (x1, y1, x2, y2).
  53. Args:
  54. boxes (Tensor): cxcywh boxes tensor with shape of (..., 4).
  55. Returns:
  56. Tensor: xyxy boxes tensor with shape of (..., 4).
  57. """
  58. ctr, wh = boxes.split((2, 2), dim=-1)
  59. return torch.cat([(ctr - wh / 2), (ctr + wh / 2)], dim=-1)
  60. @staticmethod
  61. def xyxy_to_cxcywh(boxes: Tensor) -> Tensor:
  62. """Convert box coordinates from (x1, y1, x2, y2) to (cx, cy, w, h).
  63. Args:
  64. boxes (Tensor): xyxy boxes tensor with shape of (..., 4).
  65. Returns:
  66. Tensor: cxcywh boxes tensor with shape of (..., 4).
  67. """
  68. xy1, xy2 = boxes.split((2, 2), dim=-1)
  69. return torch.cat([(xy2 + xy1) / 2, (xy2 - xy1)], dim=-1)
  70. @property
  71. def cxcywh(self) -> Tensor:
  72. """Return a tensor representing the cxcywh boxes."""
  73. return self.xyxy_to_cxcywh(self.tensor)
  74. @property
  75. def centers(self) -> Tensor:
  76. """Return a tensor representing the centers of boxes."""
  77. boxes = self.tensor
  78. return (boxes[..., :2] + boxes[..., 2:]) / 2
  79. @property
  80. def areas(self) -> Tensor:
  81. """Return a tensor representing the areas of boxes."""
  82. boxes = self.tensor
  83. return (boxes[..., 2] - boxes[..., 0]) * (
  84. boxes[..., 3] - boxes[..., 1])
  85. @property
  86. def widths(self) -> Tensor:
  87. """Return a tensor representing the widths of boxes."""
  88. boxes = self.tensor
  89. return boxes[..., 2] - boxes[..., 0]
  90. @property
  91. def heights(self) -> Tensor:
  92. """Return a tensor representing the heights of boxes."""
  93. boxes = self.tensor
  94. return boxes[..., 3] - boxes[..., 1]
  95. def flip_(self,
  96. img_shape: Tuple[int, int],
  97. direction: str = 'horizontal') -> None:
  98. """Flip boxes horizontally or vertically in-place.
  99. Args:
  100. img_shape (Tuple[int, int]): A tuple of image height and width.
  101. direction (str): Flip direction, options are "horizontal",
  102. "vertical" and "diagonal". Defaults to "horizontal"
  103. """
  104. assert direction in ['horizontal', 'vertical', 'diagonal']
  105. flipped = self.tensor
  106. boxes = flipped.clone()
  107. if direction == 'horizontal':
  108. flipped[..., 0] = img_shape[1] - boxes[..., 2]
  109. flipped[..., 2] = img_shape[1] - boxes[..., 0]
  110. elif direction == 'vertical':
  111. flipped[..., 1] = img_shape[0] - boxes[..., 3]
  112. flipped[..., 3] = img_shape[0] - boxes[..., 1]
  113. else:
  114. flipped[..., 0] = img_shape[1] - boxes[..., 2]
  115. flipped[..., 1] = img_shape[0] - boxes[..., 3]
  116. flipped[..., 2] = img_shape[1] - boxes[..., 0]
  117. flipped[..., 3] = img_shape[0] - boxes[..., 1]
  118. def translate_(self, distances: Tuple[float, float]) -> None:
  119. """Translate boxes in-place.
  120. Args:
  121. distances (Tuple[float, float]): translate distances. The first
  122. is horizontal distance and the second is vertical distance.
  123. """
  124. boxes = self.tensor
  125. assert len(distances) == 2
  126. self.tensor = boxes + boxes.new_tensor(distances).repeat(2)
  127. def clip_(self, img_shape: Tuple[int, int]) -> None:
  128. """Clip boxes according to the image shape in-place.
  129. Args:
  130. img_shape (Tuple[int, int]): A tuple of image height and width.
  131. """
  132. boxes = self.tensor
  133. boxes[..., 0::2] = boxes[..., 0::2].clamp(0, img_shape[1])
  134. boxes[..., 1::2] = boxes[..., 1::2].clamp(0, img_shape[0])
  135. def rotate_(self, center: Tuple[float, float], angle: float) -> None:
  136. """Rotate all boxes in-place.
  137. Args:
  138. center (Tuple[float, float]): Rotation origin.
  139. angle (float): Rotation angle represented in degrees. Positive
  140. values mean clockwise rotation.
  141. """
  142. boxes = self.tensor
  143. rotation_matrix = boxes.new_tensor(
  144. cv2.getRotationMatrix2D(center, -angle, 1))
  145. corners = self.hbox2corner(boxes)
  146. corners = torch.cat(
  147. [corners, corners.new_ones(*corners.shape[:-1], 1)], dim=-1)
  148. corners_T = torch.transpose(corners, -1, -2)
  149. corners_T = torch.matmul(rotation_matrix, corners_T)
  150. corners = torch.transpose(corners_T, -1, -2)
  151. self.tensor = self.corner2hbox(corners)
  152. def project_(self, homography_matrix: Union[Tensor, np.ndarray]) -> None:
  153. """Geometric transformat boxes in-place.
  154. Args:
  155. homography_matrix (Tensor or np.ndarray]):
  156. Shape (3, 3) for geometric transformation.
  157. """
  158. boxes = self.tensor
  159. if isinstance(homography_matrix, np.ndarray):
  160. homography_matrix = boxes.new_tensor(homography_matrix)
  161. corners = self.hbox2corner(boxes)
  162. corners = torch.cat(
  163. [corners, corners.new_ones(*corners.shape[:-1], 1)], dim=-1)
  164. corners_T = torch.transpose(corners, -1, -2)
  165. corners_T = torch.matmul(homography_matrix, corners_T)
  166. corners = torch.transpose(corners_T, -1, -2)
  167. # Convert to homogeneous coordinates by normalization
  168. corners = corners[..., :2] / corners[..., 2:3]
  169. self.tensor = self.corner2hbox(corners)
  170. @staticmethod
  171. def hbox2corner(boxes: Tensor) -> Tensor:
  172. """Convert box coordinates from (x1, y1, x2, y2) to corners ((x1, y1),
  173. (x2, y1), (x1, y2), (x2, y2)).
  174. Args:
  175. boxes (Tensor): Horizontal box tensor with shape of (..., 4).
  176. Returns:
  177. Tensor: Corner tensor with shape of (..., 4, 2).
  178. """
  179. x1, y1, x2, y2 = torch.split(boxes, 1, dim=-1)
  180. corners = torch.cat([x1, y1, x2, y1, x1, y2, x2, y2], dim=-1)
  181. return corners.reshape(*corners.shape[:-1], 4, 2)
  182. @staticmethod
  183. def corner2hbox(corners: Tensor) -> Tensor:
  184. """Convert box coordinates from corners ((x1, y1), (x2, y1), (x1, y2),
  185. (x2, y2)) to (x1, y1, x2, y2).
  186. Args:
  187. corners (Tensor): Corner tensor with shape of (..., 4, 2).
  188. Returns:
  189. Tensor: Horizontal box tensor with shape of (..., 4).
  190. """
  191. if corners.numel() == 0:
  192. return corners.new_zeros((0, 4))
  193. min_xy = corners.min(dim=-2)[0]
  194. max_xy = corners.max(dim=-2)[0]
  195. return torch.cat([min_xy, max_xy], dim=-1)
  196. def rescale_(self, scale_factor: Tuple[float, float]) -> None:
  197. """Rescale boxes w.r.t. rescale_factor in-place.
  198. Note:
  199. Both ``rescale_`` and ``resize_`` will enlarge or shrink boxes
  200. w.r.t ``scale_facotr``. The difference is that ``resize_`` only
  201. changes the width and the height of boxes, but ``rescale_`` also
  202. rescales the box centers simultaneously.
  203. Args:
  204. scale_factor (Tuple[float, float]): factors for scaling boxes.
  205. The length should be 2.
  206. """
  207. boxes = self.tensor
  208. assert len(scale_factor) == 2
  209. scale_factor = boxes.new_tensor(scale_factor).repeat(2)
  210. self.tensor = boxes * scale_factor
  211. def resize_(self, scale_factor: Tuple[float, float]) -> None:
  212. """Resize the box width and height w.r.t scale_factor in-place.
  213. Note:
  214. Both ``rescale_`` and ``resize_`` will enlarge or shrink boxes
  215. w.r.t ``scale_facotr``. The difference is that ``resize_`` only
  216. changes the width and the height of boxes, but ``rescale_`` also
  217. rescales the box centers simultaneously.
  218. Args:
  219. scale_factor (Tuple[float, float]): factors for scaling box
  220. shapes. The length should be 2.
  221. """
  222. boxes = self.tensor
  223. assert len(scale_factor) == 2
  224. ctrs = (boxes[..., 2:] + boxes[..., :2]) / 2
  225. wh = boxes[..., 2:] - boxes[..., :2]
  226. scale_factor = boxes.new_tensor(scale_factor)
  227. wh = wh * scale_factor
  228. xy1 = ctrs - 0.5 * wh
  229. xy2 = ctrs + 0.5 * wh
  230. self.tensor = torch.cat([xy1, xy2], dim=-1)
  231. def is_inside(self,
  232. img_shape: Tuple[int, int],
  233. all_inside: bool = False,
  234. allowed_border: int = 0) -> BoolTensor:
  235. """Find boxes inside the image.
  236. Args:
  237. img_shape (Tuple[int, int]): A tuple of image height and width.
  238. all_inside (bool): Whether the boxes are all inside the image or
  239. part inside the image. Defaults to False.
  240. allowed_border (int): Boxes that extend beyond the image shape
  241. boundary by more than ``allowed_border`` are considered
  242. "outside" Defaults to 0.
  243. Returns:
  244. BoolTensor: A BoolTensor indicating whether the box is inside
  245. the image. Assuming the original boxes have shape (m, n, 4),
  246. the output has shape (m, n).
  247. """
  248. img_h, img_w = img_shape
  249. boxes = self.tensor
  250. if all_inside:
  251. return (boxes[:, 0] >= -allowed_border) & \
  252. (boxes[:, 1] >= -allowed_border) & \
  253. (boxes[:, 2] < img_w + allowed_border) & \
  254. (boxes[:, 3] < img_h + allowed_border)
  255. else:
  256. return (boxes[..., 0] < img_w + allowed_border) & \
  257. (boxes[..., 1] < img_h + allowed_border) & \
  258. (boxes[..., 2] > -allowed_border) & \
  259. (boxes[..., 3] > -allowed_border)
  260. def find_inside_points(self,
  261. points: Tensor,
  262. is_aligned: bool = False) -> BoolTensor:
  263. """Find inside box points. Boxes dimension must be 2.
  264. Args:
  265. points (Tensor): Points coordinates. Has shape of (m, 2).
  266. is_aligned (bool): Whether ``points`` has been aligned with boxes
  267. or not. If True, the length of boxes and ``points`` should be
  268. the same. Defaults to False.
  269. Returns:
  270. BoolTensor: A BoolTensor indicating whether a point is inside
  271. boxes. Assuming the boxes has shape of (n, 4), if ``is_aligned``
  272. is False. The index has shape of (m, n). If ``is_aligned`` is
  273. True, m should be equal to n and the index has shape of (m, ).
  274. """
  275. boxes = self.tensor
  276. assert boxes.dim() == 2, 'boxes dimension must be 2.'
  277. if not is_aligned:
  278. boxes = boxes[None, :, :]
  279. points = points[:, None, :]
  280. else:
  281. assert boxes.size(0) == points.size(0)
  282. x_min, y_min, x_max, y_max = boxes.unbind(dim=-1)
  283. return (points[..., 0] >= x_min) & (points[..., 0] <= x_max) & \
  284. (points[..., 1] >= y_min) & (points[..., 1] <= y_max)
  285. def create_masks(self, img_shape: Tuple[int, int]) -> BitmapMasks:
  286. """
  287. Args:
  288. img_shape (Tuple[int, int]): A tuple of image height and width.
  289. Returns:
  290. :obj:`BitmapMasks`: Converted masks
  291. """
  292. img_h, img_w = img_shape
  293. boxes = self.tensor
  294. xmin, ymin = boxes[:, 0:1], boxes[:, 1:2]
  295. xmax, ymax = boxes[:, 2:3], boxes[:, 3:4]
  296. gt_masks = np.zeros((len(boxes), img_h, img_w), dtype=np.uint8)
  297. for i in range(len(boxes)):
  298. gt_masks[i,
  299. int(ymin[i]):int(ymax[i]),
  300. int(xmin[i]):int(xmax[i])] = 1
  301. return BitmapMasks(gt_masks, img_h, img_w)
  302. @staticmethod
  303. def overlaps(boxes1: BaseBoxes,
  304. boxes2: BaseBoxes,
  305. mode: str = 'iou',
  306. is_aligned: bool = False,
  307. eps: float = 1e-6) -> Tensor:
  308. """Calculate overlap between two set of boxes with their types
  309. converted to ``HorizontalBoxes``.
  310. Args:
  311. boxes1 (:obj:`BaseBoxes`): BaseBoxes with shape of (m, box_dim)
  312. or empty.
  313. boxes2 (:obj:`BaseBoxes`): BaseBoxes with shape of (n, box_dim)
  314. or empty.
  315. mode (str): "iou" (intersection over union), "iof" (intersection
  316. over foreground). Defaults to "iou".
  317. is_aligned (bool): If True, then m and n must be equal. Defaults
  318. to False.
  319. eps (float): A value added to the denominator for numerical
  320. stability. Defaults to 1e-6.
  321. Returns:
  322. Tensor: shape (m, n) if ``is_aligned`` is False else shape (m,)
  323. """
  324. boxes1 = boxes1.convert_to('hbox')
  325. boxes2 = boxes2.convert_to('hbox')
  326. return bbox_overlaps(
  327. boxes1.tensor,
  328. boxes2.tensor,
  329. mode=mode,
  330. is_aligned=is_aligned,
  331. eps=eps)
  332. @staticmethod
  333. def from_instance_masks(masks: MaskType) -> 'HorizontalBoxes':
  334. """Create horizontal boxes from instance masks.
  335. Args:
  336. masks (:obj:`BitmapMasks` or :obj:`PolygonMasks`): BitmapMasks or
  337. PolygonMasks instance with length of n.
  338. Returns:
  339. :obj:`HorizontalBoxes`: Converted boxes with shape of (n, 4).
  340. """
  341. num_masks = len(masks)
  342. boxes = np.zeros((num_masks, 4), dtype=np.float32)
  343. if isinstance(masks, BitmapMasks):
  344. x_any = masks.masks.any(axis=1)
  345. y_any = masks.masks.any(axis=2)
  346. for idx in range(num_masks):
  347. x = np.where(x_any[idx, :])[0]
  348. y = np.where(y_any[idx, :])[0]
  349. if len(x) > 0 and len(y) > 0:
  350. # use +1 for x_max and y_max so that the right and bottom
  351. # boundary of instance masks are fully included by the box
  352. boxes[idx, :] = np.array(
  353. [x[0], y[0], x[-1] + 1, y[-1] + 1], dtype=np.float32)
  354. elif isinstance(masks, PolygonMasks):
  355. for idx, poly_per_obj in enumerate(masks.masks):
  356. # simply use a number that is big enough for comparison with
  357. # coordinates
  358. xy_min = np.array([masks.width * 2, masks.height * 2],
  359. dtype=np.float32)
  360. xy_max = np.zeros(2, dtype=np.float32)
  361. for p in poly_per_obj:
  362. xy = np.array(p).reshape(-1, 2).astype(np.float32)
  363. xy_min = np.minimum(xy_min, np.min(xy, axis=0))
  364. xy_max = np.maximum(xy_max, np.max(xy, axis=0))
  365. boxes[idx, :2] = xy_min
  366. boxes[idx, 2:] = xy_max
  367. else:
  368. raise TypeError(
  369. '`masks` must be `BitmapMasks` or `PolygonMasks`, '
  370. f'but got {type(masks)}.')
  371. return HorizontalBoxes(boxes)