base_boxes.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from abc import ABCMeta, abstractmethod, abstractproperty, abstractstaticmethod
  3. from typing import List, Optional, Sequence, Tuple, Type, TypeVar, Union
  4. import numpy as np
  5. import torch
  6. from torch import BoolTensor, Tensor
  7. from mmdet.structures.mask.structures import BitmapMasks, PolygonMasks
  8. T = TypeVar('T')
  9. DeviceType = Union[str, torch.device]
  10. IndexType = Union[slice, int, list, torch.LongTensor, torch.cuda.LongTensor,
  11. torch.BoolTensor, torch.cuda.BoolTensor, np.ndarray]
  12. MaskType = Union[BitmapMasks, PolygonMasks]
  13. class BaseBoxes(metaclass=ABCMeta):
  14. """The base class for 2D box types.
  15. The functions of ``BaseBoxes`` lie in three fields:
  16. - Verify the boxes shape.
  17. - Support tensor-like operations.
  18. - Define abstract functions for 2D boxes.
  19. In ``__init__`` , ``BaseBoxes`` verifies the validity of the data shape
  20. w.r.t ``box_dim``. The tensor with the dimension >= 2 and the length
  21. of the last dimension being ``box_dim`` will be regarded as valid.
  22. ``BaseBoxes`` will restore them at the field ``tensor``. It's necessary
  23. to override ``box_dim`` in subclass to guarantee the data shape is
  24. correct.
  25. There are many basic tensor-like functions implemented in ``BaseBoxes``.
  26. In most cases, users can operate ``BaseBoxes`` instance like a normal
  27. tensor. To protect the validity of data shape, All tensor-like functions
  28. cannot modify the last dimension of ``self.tensor``.
  29. When creating a new box type, users need to inherit from ``BaseBoxes``
  30. and override abstract methods and specify the ``box_dim``. Then, register
  31. the new box type by using the decorator ``register_box_type``.
  32. Args:
  33. data (Tensor or np.ndarray or Sequence): The box data with shape
  34. (..., box_dim).
  35. dtype (torch.dtype, Optional): data type of boxes. Defaults to None.
  36. device (str or torch.device, Optional): device of boxes.
  37. Default to None.
  38. clone (bool): Whether clone ``boxes`` or not. Defaults to True.
  39. """
  40. # Used to verify the last dimension length
  41. # Should override it in subclass.
  42. box_dim: int = 0
  43. def __init__(self,
  44. data: Union[Tensor, np.ndarray, Sequence],
  45. dtype: Optional[torch.dtype] = None,
  46. device: Optional[DeviceType] = None,
  47. clone: bool = True) -> None:
  48. if isinstance(data, (np.ndarray, Tensor, Sequence)):
  49. data = torch.as_tensor(data)
  50. else:
  51. raise TypeError('boxes should be Tensor, ndarray, or Sequence, ',
  52. f'but got {type(data)}')
  53. if device is not None or dtype is not None:
  54. data = data.to(dtype=dtype, device=device)
  55. # Clone the data to avoid potential bugs
  56. if clone:
  57. data = data.clone()
  58. # handle the empty input like []
  59. if data.numel() == 0:
  60. data = data.reshape((-1, self.box_dim))
  61. assert data.dim() >= 2 and data.size(-1) == self.box_dim, \
  62. ('The boxes dimension must >= 2 and the length of the last '
  63. f'dimension must be {self.box_dim}, but got boxes with '
  64. f'shape {data.shape}.')
  65. self.tensor = data
  66. def convert_to(self, dst_type: Union[str, type]) -> 'BaseBoxes':
  67. """Convert self to another box type.
  68. Args:
  69. dst_type (str or type): destination box type.
  70. Returns:
  71. :obj:`BaseBoxes`: destination box type object .
  72. """
  73. from .box_type import convert_box_type
  74. return convert_box_type(self, dst_type=dst_type)
  75. def empty_boxes(self: T,
  76. dtype: Optional[torch.dtype] = None,
  77. device: Optional[DeviceType] = None) -> T:
  78. """Create empty box.
  79. Args:
  80. dtype (torch.dtype, Optional): data type of boxes.
  81. device (str or torch.device, Optional): device of boxes.
  82. Returns:
  83. T: empty boxes with shape of (0, box_dim).
  84. """
  85. empty_box = self.tensor.new_zeros(
  86. 0, self.box_dim, dtype=dtype, device=device)
  87. return type(self)(empty_box, clone=False)
  88. def fake_boxes(self: T,
  89. sizes: Tuple[int],
  90. fill: float = 0,
  91. dtype: Optional[torch.dtype] = None,
  92. device: Optional[DeviceType] = None) -> T:
  93. """Create fake boxes with specific sizes and fill values.
  94. Args:
  95. sizes (Tuple[int]): The size of fake boxes. The last value must
  96. be equal with ``self.box_dim``.
  97. fill (float): filling value. Defaults to 0.
  98. dtype (torch.dtype, Optional): data type of boxes.
  99. device (str or torch.device, Optional): device of boxes.
  100. Returns:
  101. T: Fake boxes with shape of ``sizes``.
  102. """
  103. fake_boxes = self.tensor.new_full(
  104. sizes, fill, dtype=dtype, device=device)
  105. return type(self)(fake_boxes, clone=False)
  106. def __getitem__(self: T, index: IndexType) -> T:
  107. """Rewrite getitem to protect the last dimension shape."""
  108. boxes = self.tensor
  109. if isinstance(index, np.ndarray):
  110. index = torch.as_tensor(index, device=self.device)
  111. if isinstance(index, Tensor) and index.dtype == torch.bool:
  112. assert index.dim() < boxes.dim()
  113. elif isinstance(index, tuple):
  114. assert len(index) < boxes.dim()
  115. # `Ellipsis`(...) is commonly used in index like [None, ...].
  116. # When `Ellipsis` is in index, it must be the last item.
  117. if Ellipsis in index:
  118. assert index[-1] is Ellipsis
  119. boxes = boxes[index]
  120. if boxes.dim() == 1:
  121. boxes = boxes.reshape(1, -1)
  122. return type(self)(boxes, clone=False)
  123. def __setitem__(self: T, index: IndexType, values: Union[Tensor, T]) -> T:
  124. """Rewrite setitem to protect the last dimension shape."""
  125. assert type(values) is type(self), \
  126. 'The value to be set must be the same box type as self'
  127. values = values.tensor
  128. if isinstance(index, np.ndarray):
  129. index = torch.as_tensor(index, device=self.device)
  130. if isinstance(index, Tensor) and index.dtype == torch.bool:
  131. assert index.dim() < self.tensor.dim()
  132. elif isinstance(index, tuple):
  133. assert len(index) < self.tensor.dim()
  134. # `Ellipsis`(...) is commonly used in index like [None, ...].
  135. # When `Ellipsis` is in index, it must be the last item.
  136. if Ellipsis in index:
  137. assert index[-1] is Ellipsis
  138. self.tensor[index] = values
  139. def __len__(self) -> int:
  140. """Return the length of self.tensor first dimension."""
  141. return self.tensor.size(0)
  142. def __deepcopy__(self, memo):
  143. """Only clone the ``self.tensor`` when applying deepcopy."""
  144. cls = self.__class__
  145. other = cls.__new__(cls)
  146. memo[id(self)] = other
  147. other.tensor = self.tensor.clone()
  148. return other
  149. def __repr__(self) -> str:
  150. """Return a strings that describes the object."""
  151. return self.__class__.__name__ + '(\n' + str(self.tensor) + ')'
  152. def new_tensor(self, *args, **kwargs) -> Tensor:
  153. """Reload ``new_tensor`` from self.tensor."""
  154. return self.tensor.new_tensor(*args, **kwargs)
  155. def new_full(self, *args, **kwargs) -> Tensor:
  156. """Reload ``new_full`` from self.tensor."""
  157. return self.tensor.new_full(*args, **kwargs)
  158. def new_empty(self, *args, **kwargs) -> Tensor:
  159. """Reload ``new_empty`` from self.tensor."""
  160. return self.tensor.new_empty(*args, **kwargs)
  161. def new_ones(self, *args, **kwargs) -> Tensor:
  162. """Reload ``new_ones`` from self.tensor."""
  163. return self.tensor.new_ones(*args, **kwargs)
  164. def new_zeros(self, *args, **kwargs) -> Tensor:
  165. """Reload ``new_zeros`` from self.tensor."""
  166. return self.tensor.new_zeros(*args, **kwargs)
  167. def size(self, dim: Optional[int] = None) -> Union[int, torch.Size]:
  168. """Reload new_zeros from self.tensor."""
  169. # self.tensor.size(dim) cannot work when dim=None.
  170. return self.tensor.size() if dim is None else self.tensor.size(dim)
  171. def dim(self) -> int:
  172. """Reload ``dim`` from self.tensor."""
  173. return self.tensor.dim()
  174. @property
  175. def device(self) -> torch.device:
  176. """Reload ``device`` from self.tensor."""
  177. return self.tensor.device
  178. @property
  179. def dtype(self) -> torch.dtype:
  180. """Reload ``dtype`` from self.tensor."""
  181. return self.tensor.dtype
  182. @property
  183. def shape(self) -> torch.Size:
  184. return self.tensor.shape
  185. def numel(self) -> int:
  186. """Reload ``numel`` from self.tensor."""
  187. return self.tensor.numel()
  188. def numpy(self) -> np.ndarray:
  189. """Reload ``numpy`` from self.tensor."""
  190. return self.tensor.numpy()
  191. def to(self: T, *args, **kwargs) -> T:
  192. """Reload ``to`` from self.tensor."""
  193. return type(self)(self.tensor.to(*args, **kwargs), clone=False)
  194. def cpu(self: T) -> T:
  195. """Reload ``cpu`` from self.tensor."""
  196. return type(self)(self.tensor.cpu(), clone=False)
  197. def cuda(self: T, *args, **kwargs) -> T:
  198. """Reload ``cuda`` from self.tensor."""
  199. return type(self)(self.tensor.cuda(*args, **kwargs), clone=False)
  200. def clone(self: T) -> T:
  201. """Reload ``clone`` from self.tensor."""
  202. return type(self)(self.tensor)
  203. def detach(self: T) -> T:
  204. """Reload ``detach`` from self.tensor."""
  205. return type(self)(self.tensor.detach(), clone=False)
  206. def view(self: T, *shape: Tuple[int]) -> T:
  207. """Reload ``view`` from self.tensor."""
  208. return type(self)(self.tensor.view(shape), clone=False)
  209. def reshape(self: T, *shape: Tuple[int]) -> T:
  210. """Reload ``reshape`` from self.tensor."""
  211. return type(self)(self.tensor.reshape(shape), clone=False)
  212. def expand(self: T, *sizes: Tuple[int]) -> T:
  213. """Reload ``expand`` from self.tensor."""
  214. return type(self)(self.tensor.expand(sizes), clone=False)
  215. def repeat(self: T, *sizes: Tuple[int]) -> T:
  216. """Reload ``repeat`` from self.tensor."""
  217. return type(self)(self.tensor.repeat(sizes), clone=False)
  218. def transpose(self: T, dim0: int, dim1: int) -> T:
  219. """Reload ``transpose`` from self.tensor."""
  220. ndim = self.tensor.dim()
  221. assert dim0 != -1 and dim0 != ndim - 1
  222. assert dim1 != -1 and dim1 != ndim - 1
  223. return type(self)(self.tensor.transpose(dim0, dim1), clone=False)
  224. def permute(self: T, *dims: Tuple[int]) -> T:
  225. """Reload ``permute`` from self.tensor."""
  226. assert dims[-1] == -1 or dims[-1] == self.tensor.dim() - 1
  227. return type(self)(self.tensor.permute(dims), clone=False)
  228. def split(self: T,
  229. split_size_or_sections: Union[int, Sequence[int]],
  230. dim: int = 0) -> List[T]:
  231. """Reload ``split`` from self.tensor."""
  232. assert dim != -1 and dim != self.tensor.dim() - 1
  233. boxes_list = self.tensor.split(split_size_or_sections, dim=dim)
  234. return [type(self)(boxes, clone=False) for boxes in boxes_list]
  235. def chunk(self: T, chunks: int, dim: int = 0) -> List[T]:
  236. """Reload ``chunk`` from self.tensor."""
  237. assert dim != -1 and dim != self.tensor.dim() - 1
  238. boxes_list = self.tensor.chunk(chunks, dim=dim)
  239. return [type(self)(boxes, clone=False) for boxes in boxes_list]
  240. def unbind(self: T, dim: int = 0) -> T:
  241. """Reload ``unbind`` from self.tensor."""
  242. assert dim != -1 and dim != self.tensor.dim() - 1
  243. boxes_list = self.tensor.unbind(dim=dim)
  244. return [type(self)(boxes, clone=False) for boxes in boxes_list]
  245. def flatten(self: T, start_dim: int = 0, end_dim: int = -2) -> T:
  246. """Reload ``flatten`` from self.tensor."""
  247. assert end_dim != -1 and end_dim != self.tensor.dim() - 1
  248. return type(self)(self.tensor.flatten(start_dim, end_dim), clone=False)
  249. def squeeze(self: T, dim: Optional[int] = None) -> T:
  250. """Reload ``squeeze`` from self.tensor."""
  251. boxes = self.tensor.squeeze() if dim is None else \
  252. self.tensor.squeeze(dim)
  253. return type(self)(boxes, clone=False)
  254. def unsqueeze(self: T, dim: int) -> T:
  255. """Reload ``unsqueeze`` from self.tensor."""
  256. assert dim != -1 and dim != self.tensor.dim()
  257. return type(self)(self.tensor.unsqueeze(dim), clone=False)
  258. @classmethod
  259. def cat(cls: Type[T], box_list: Sequence[T], dim: int = 0) -> T:
  260. """Cancatenates a box instance list into one single box instance.
  261. Similar to ``torch.cat``.
  262. Args:
  263. box_list (Sequence[T]): A sequence of box instances.
  264. dim (int): The dimension over which the box are concatenated.
  265. Defaults to 0.
  266. Returns:
  267. T: Concatenated box instance.
  268. """
  269. assert isinstance(box_list, Sequence)
  270. if len(box_list) == 0:
  271. raise ValueError('box_list should not be a empty list.')
  272. assert dim != -1 and dim != box_list[0].dim() - 1
  273. assert all(isinstance(boxes, cls) for boxes in box_list)
  274. th_box_list = [boxes.tensor for boxes in box_list]
  275. return cls(torch.cat(th_box_list, dim=dim), clone=False)
  276. @classmethod
  277. def stack(cls: Type[T], box_list: Sequence[T], dim: int = 0) -> T:
  278. """Concatenates a sequence of tensors along a new dimension. Similar to
  279. ``torch.stack``.
  280. Args:
  281. box_list (Sequence[T]): A sequence of box instances.
  282. dim (int): Dimension to insert. Defaults to 0.
  283. Returns:
  284. T: Concatenated box instance.
  285. """
  286. assert isinstance(box_list, Sequence)
  287. if len(box_list) == 0:
  288. raise ValueError('box_list should not be a empty list.')
  289. assert dim != -1 and dim != box_list[0].dim()
  290. assert all(isinstance(boxes, cls) for boxes in box_list)
  291. th_box_list = [boxes.tensor for boxes in box_list]
  292. return cls(torch.stack(th_box_list, dim=dim), clone=False)
  293. @abstractproperty
  294. def centers(self) -> Tensor:
  295. """Return a tensor representing the centers of boxes."""
  296. pass
  297. @abstractproperty
  298. def areas(self) -> Tensor:
  299. """Return a tensor representing the areas of boxes."""
  300. pass
  301. @abstractproperty
  302. def widths(self) -> Tensor:
  303. """Return a tensor representing the widths of boxes."""
  304. pass
  305. @abstractproperty
  306. def heights(self) -> Tensor:
  307. """Return a tensor representing the heights of boxes."""
  308. pass
  309. @abstractmethod
  310. def flip_(self,
  311. img_shape: Tuple[int, int],
  312. direction: str = 'horizontal') -> None:
  313. """Flip boxes horizontally or vertically in-place.
  314. Args:
  315. img_shape (Tuple[int, int]): A tuple of image height and width.
  316. direction (str): Flip direction, options are "horizontal",
  317. "vertical" and "diagonal". Defaults to "horizontal"
  318. """
  319. pass
  320. @abstractmethod
  321. def translate_(self, distances: Tuple[float, float]) -> None:
  322. """Translate boxes in-place.
  323. Args:
  324. distances (Tuple[float, float]): translate distances. The first
  325. is horizontal distance and the second is vertical distance.
  326. """
  327. pass
  328. @abstractmethod
  329. def clip_(self, img_shape: Tuple[int, int]) -> None:
  330. """Clip boxes according to the image shape in-place.
  331. Args:
  332. img_shape (Tuple[int, int]): A tuple of image height and width.
  333. """
  334. pass
  335. @abstractmethod
  336. def rotate_(self, center: Tuple[float, float], angle: float) -> None:
  337. """Rotate all boxes in-place.
  338. Args:
  339. center (Tuple[float, float]): Rotation origin.
  340. angle (float): Rotation angle represented in degrees. Positive
  341. values mean clockwise rotation.
  342. """
  343. pass
  344. @abstractmethod
  345. def project_(self, homography_matrix: Union[Tensor, np.ndarray]) -> None:
  346. """Geometric transformat boxes in-place.
  347. Args:
  348. homography_matrix (Tensor or np.ndarray]):
  349. Shape (3, 3) for geometric transformation.
  350. """
  351. pass
  352. @abstractmethod
  353. def rescale_(self, scale_factor: Tuple[float, float]) -> None:
  354. """Rescale boxes w.r.t. rescale_factor in-place.
  355. Note:
  356. Both ``rescale_`` and ``resize_`` will enlarge or shrink boxes
  357. w.r.t ``scale_facotr``. The difference is that ``resize_`` only
  358. changes the width and the height of boxes, but ``rescale_`` also
  359. rescales the box centers simultaneously.
  360. Args:
  361. scale_factor (Tuple[float, float]): factors for scaling boxes.
  362. The length should be 2.
  363. """
  364. pass
  365. @abstractmethod
  366. def resize_(self, scale_factor: Tuple[float, float]) -> None:
  367. """Resize the box width and height w.r.t scale_factor in-place.
  368. Note:
  369. Both ``rescale_`` and ``resize_`` will enlarge or shrink boxes
  370. w.r.t ``scale_facotr``. The difference is that ``resize_`` only
  371. changes the width and the height of boxes, but ``rescale_`` also
  372. rescales the box centers simultaneously.
  373. Args:
  374. scale_factor (Tuple[float, float]): factors for scaling box
  375. shapes. The length should be 2.
  376. """
  377. pass
  378. @abstractmethod
  379. def is_inside(self,
  380. img_shape: Tuple[int, int],
  381. all_inside: bool = False,
  382. allowed_border: int = 0) -> BoolTensor:
  383. """Find boxes inside the image.
  384. Args:
  385. img_shape (Tuple[int, int]): A tuple of image height and width.
  386. all_inside (bool): Whether the boxes are all inside the image or
  387. part inside the image. Defaults to False.
  388. allowed_border (int): Boxes that extend beyond the image shape
  389. boundary by more than ``allowed_border`` are considered
  390. "outside" Defaults to 0.
  391. Returns:
  392. BoolTensor: A BoolTensor indicating whether the box is inside
  393. the image. Assuming the original boxes have shape (m, n, box_dim),
  394. the output has shape (m, n).
  395. """
  396. pass
  397. @abstractmethod
  398. def find_inside_points(self,
  399. points: Tensor,
  400. is_aligned: bool = False) -> BoolTensor:
  401. """Find inside box points. Boxes dimension must be 2.
  402. Args:
  403. points (Tensor): Points coordinates. Has shape of (m, 2).
  404. is_aligned (bool): Whether ``points`` has been aligned with boxes
  405. or not. If True, the length of boxes and ``points`` should be
  406. the same. Defaults to False.
  407. Returns:
  408. BoolTensor: A BoolTensor indicating whether a point is inside
  409. boxes. Assuming the boxes has shape of (n, box_dim), if
  410. ``is_aligned`` is False. The index has shape of (m, n). If
  411. ``is_aligned`` is True, m should be equal to n and the index has
  412. shape of (m, ).
  413. """
  414. pass
  415. @abstractstaticmethod
  416. def overlaps(boxes1: 'BaseBoxes',
  417. boxes2: 'BaseBoxes',
  418. mode: str = 'iou',
  419. is_aligned: bool = False,
  420. eps: float = 1e-6) -> Tensor:
  421. """Calculate overlap between two set of boxes with their types
  422. converted to the present box type.
  423. Args:
  424. boxes1 (:obj:`BaseBoxes`): BaseBoxes with shape of (m, box_dim)
  425. or empty.
  426. boxes2 (:obj:`BaseBoxes`): BaseBoxes with shape of (n, box_dim)
  427. or empty.
  428. mode (str): "iou" (intersection over union), "iof" (intersection
  429. over foreground). Defaults to "iou".
  430. is_aligned (bool): If True, then m and n must be equal. Defaults
  431. to False.
  432. eps (float): A value added to the denominator for numerical
  433. stability. Defaults to 1e-6.
  434. Returns:
  435. Tensor: shape (m, n) if ``is_aligned`` is False else shape (m,)
  436. """
  437. pass
  438. @abstractstaticmethod
  439. def from_instance_masks(masks: MaskType) -> 'BaseBoxes':
  440. """Create boxes from instance masks.
  441. Args:
  442. masks (:obj:`BitmapMasks` or :obj:`PolygonMasks`): BitmapMasks or
  443. PolygonMasks instance with length of n.
  444. Returns:
  445. :obj:`BaseBoxes`: Converted boxes with shape of (n, box_dim).
  446. """
  447. pass