max_iou_assigner.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import Optional, Union
  3. import torch
  4. from mmengine.structures import InstanceData
  5. from torch import Tensor
  6. from mmdet.registry import TASK_UTILS
  7. from .assign_result import AssignResult
  8. from .base_assigner import BaseAssigner
  9. @TASK_UTILS.register_module()
  10. class MaxIoUAssigner(BaseAssigner):
  11. """Assign a corresponding gt bbox or background to each bbox.
  12. Each proposals will be assigned with `-1`, or a semi-positive integer
  13. indicating the ground truth index.
  14. - -1: negative sample, no assigned gt
  15. - semi-positive integer: positive sample, index (0-based) of assigned gt
  16. Args:
  17. pos_iou_thr (float): IoU threshold for positive bboxes.
  18. neg_iou_thr (float or tuple): IoU threshold for negative bboxes.
  19. min_pos_iou (float): Minimum iou for a bbox to be considered as a
  20. positive bbox. Positive samples can have smaller IoU than
  21. pos_iou_thr due to the 4th step (assign max IoU sample to each gt).
  22. `min_pos_iou` is set to avoid assigning bboxes that have extremely
  23. small iou with GT as positive samples. It brings about 0.3 mAP
  24. improvements in 1x schedule but does not affect the performance of
  25. 3x schedule. More comparisons can be found in
  26. `PR #7464 <https://github.com/open-mmlab/mmdetection/pull/7464>`_.
  27. gt_max_assign_all (bool): Whether to assign all bboxes with the same
  28. highest overlap with some gt to that gt.
  29. ignore_iof_thr (float): IoF threshold for ignoring bboxes (if
  30. `gt_bboxes_ignore` is specified). Negative values mean not
  31. ignoring any bboxes.
  32. ignore_wrt_candidates (bool): Whether to compute the iof between
  33. `bboxes` and `gt_bboxes_ignore`, or the contrary.
  34. match_low_quality (bool): Whether to allow low quality matches. This is
  35. usually allowed for RPN and single stage detectors, but not allowed
  36. in the second stage. Details are demonstrated in Step 4.
  37. gpu_assign_thr (int): The upper bound of the number of GT for GPU
  38. assign. When the number of gt is above this threshold, will assign
  39. on CPU device. Negative values mean not assign on CPU.
  40. iou_calculator (dict): Config of overlaps Calculator.
  41. """
  42. def __init__(self,
  43. pos_iou_thr: float,
  44. neg_iou_thr: Union[float, tuple],
  45. min_pos_iou: float = .0,
  46. gt_max_assign_all: bool = True,
  47. ignore_iof_thr: float = -1,
  48. ignore_wrt_candidates: bool = True,
  49. match_low_quality: bool = True,
  50. gpu_assign_thr: float = -1,
  51. iou_calculator: dict = dict(type='BboxOverlaps2D')):
  52. self.pos_iou_thr = pos_iou_thr
  53. self.neg_iou_thr = neg_iou_thr
  54. self.min_pos_iou = min_pos_iou
  55. self.gt_max_assign_all = gt_max_assign_all
  56. self.ignore_iof_thr = ignore_iof_thr
  57. self.ignore_wrt_candidates = ignore_wrt_candidates
  58. self.gpu_assign_thr = gpu_assign_thr
  59. self.match_low_quality = match_low_quality
  60. self.iou_calculator = TASK_UTILS.build(iou_calculator)
  61. def assign(self,
  62. pred_instances: InstanceData,
  63. gt_instances: InstanceData,
  64. gt_instances_ignore: Optional[InstanceData] = None,
  65. **kwargs) -> AssignResult:
  66. """Assign gt to bboxes.
  67. This method assign a gt bbox to every bbox (proposal/anchor), each bbox
  68. will be assigned with -1, or a semi-positive number. -1 means negative
  69. sample, semi-positive number is the index (0-based) of assigned gt.
  70. The assignment is done in following steps, the order matters.
  71. 1. assign every bbox to the background
  72. 2. assign proposals whose iou with all gts < neg_iou_thr to 0
  73. 3. for each bbox, if the iou with its nearest gt >= pos_iou_thr,
  74. assign it to that bbox
  75. 4. for each gt bbox, assign its nearest proposals (may be more than
  76. one) to itself
  77. Args:
  78. pred_instances (:obj:`InstanceData`): Instances of model
  79. predictions. It includes ``priors``, and the priors can
  80. be anchors or points, or the bboxes predicted by the
  81. previous stage, has shape (n, 4). The bboxes predicted by
  82. the current model or stage will be named ``bboxes``,
  83. ``labels``, and ``scores``, the same as the ``InstanceData``
  84. in other places.
  85. gt_instances (:obj:`InstanceData`): Ground truth of instance
  86. annotations. It usually includes ``bboxes``, with shape (k, 4),
  87. and ``labels``, with shape (k, ).
  88. gt_instances_ignore (:obj:`InstanceData`, optional): Instances
  89. to be ignored during training. It includes ``bboxes``
  90. attribute data that is ignored during training and testing.
  91. Defaults to None.
  92. Returns:
  93. :obj:`AssignResult`: The assign result.
  94. Example:
  95. >>> from mmengine.structures import InstanceData
  96. >>> self = MaxIoUAssigner(0.5, 0.5)
  97. >>> pred_instances = InstanceData()
  98. >>> pred_instances.priors = torch.Tensor([[0, 0, 10, 10],
  99. ... [10, 10, 20, 20]])
  100. >>> gt_instances = InstanceData()
  101. >>> gt_instances.bboxes = torch.Tensor([[0, 0, 10, 9]])
  102. >>> gt_instances.labels = torch.Tensor([0])
  103. >>> assign_result = self.assign(pred_instances, gt_instances)
  104. >>> expected_gt_inds = torch.LongTensor([1, 0])
  105. >>> assert torch.all(assign_result.gt_inds == expected_gt_inds)
  106. """
  107. gt_bboxes = gt_instances.bboxes
  108. priors = pred_instances.priors
  109. gt_labels = gt_instances.labels
  110. if gt_instances_ignore is not None:
  111. gt_bboxes_ignore = gt_instances_ignore.bboxes
  112. else:
  113. gt_bboxes_ignore = None
  114. assign_on_cpu = True if (self.gpu_assign_thr > 0) and (
  115. gt_bboxes.shape[0] > self.gpu_assign_thr) else False
  116. # compute overlap and assign gt on CPU when number of GT is large
  117. if assign_on_cpu:
  118. device = priors.device
  119. priors = priors.cpu()
  120. gt_bboxes = gt_bboxes.cpu()
  121. gt_labels = gt_labels.cpu()
  122. if gt_bboxes_ignore is not None:
  123. gt_bboxes_ignore = gt_bboxes_ignore.cpu()
  124. overlaps = self.iou_calculator(gt_bboxes, priors)
  125. if (self.ignore_iof_thr > 0 and gt_bboxes_ignore is not None
  126. and gt_bboxes_ignore.numel() > 0 and priors.numel() > 0):
  127. if self.ignore_wrt_candidates:
  128. ignore_overlaps = self.iou_calculator(
  129. priors, gt_bboxes_ignore, mode='iof')
  130. ignore_max_overlaps, _ = ignore_overlaps.max(dim=1)
  131. else:
  132. ignore_overlaps = self.iou_calculator(
  133. gt_bboxes_ignore, priors, mode='iof')
  134. ignore_max_overlaps, _ = ignore_overlaps.max(dim=0)
  135. overlaps[:, ignore_max_overlaps > self.ignore_iof_thr] = -1
  136. assign_result = self.assign_wrt_overlaps(overlaps, gt_labels)
  137. if assign_on_cpu:
  138. assign_result.gt_inds = assign_result.gt_inds.to(device)
  139. assign_result.max_overlaps = assign_result.max_overlaps.to(device)
  140. if assign_result.labels is not None:
  141. assign_result.labels = assign_result.labels.to(device)
  142. return assign_result
  143. def assign_wrt_overlaps(self, overlaps: Tensor,
  144. gt_labels: Tensor) -> AssignResult:
  145. """Assign w.r.t. the overlaps of priors with gts.
  146. Args:
  147. overlaps (Tensor): Overlaps between k gt_bboxes and n bboxes,
  148. shape(k, n).
  149. gt_labels (Tensor): Labels of k gt_bboxes, shape (k, ).
  150. Returns:
  151. :obj:`AssignResult`: The assign result.
  152. """
  153. num_gts, num_bboxes = overlaps.size(0), overlaps.size(1)
  154. # 1. assign -1 by default
  155. assigned_gt_inds = overlaps.new_full((num_bboxes, ),
  156. -1,
  157. dtype=torch.long)
  158. if num_gts == 0 or num_bboxes == 0:
  159. # No ground truth or boxes, return empty assignment
  160. max_overlaps = overlaps.new_zeros((num_bboxes, ))
  161. assigned_labels = overlaps.new_full((num_bboxes, ),
  162. -1,
  163. dtype=torch.long)
  164. if num_gts == 0:
  165. # No truth, assign everything to background
  166. assigned_gt_inds[:] = 0
  167. return AssignResult(
  168. num_gts=num_gts,
  169. gt_inds=assigned_gt_inds,
  170. max_overlaps=max_overlaps,
  171. labels=assigned_labels)
  172. # for each anchor, which gt best overlaps with it
  173. # for each anchor, the max iou of all gts
  174. max_overlaps, argmax_overlaps = overlaps.max(dim=0)
  175. # for each gt, which anchor best overlaps with it
  176. # for each gt, the max iou of all proposals
  177. gt_max_overlaps, gt_argmax_overlaps = overlaps.max(dim=1)
  178. # 2. assign negative: below
  179. # the negative inds are set to be 0
  180. if isinstance(self.neg_iou_thr, float):
  181. assigned_gt_inds[(max_overlaps >= 0)
  182. & (max_overlaps < self.neg_iou_thr)] = 0
  183. elif isinstance(self.neg_iou_thr, tuple):
  184. assert len(self.neg_iou_thr) == 2
  185. assigned_gt_inds[(max_overlaps >= self.neg_iou_thr[0])
  186. & (max_overlaps < self.neg_iou_thr[1])] = 0
  187. # 3. assign positive: above positive IoU threshold
  188. pos_inds = max_overlaps >= self.pos_iou_thr
  189. assigned_gt_inds[pos_inds] = argmax_overlaps[pos_inds] + 1
  190. if self.match_low_quality:
  191. # Low-quality matching will overwrite the assigned_gt_inds assigned
  192. # in Step 3. Thus, the assigned gt might not be the best one for
  193. # prediction.
  194. # For example, if bbox A has 0.9 and 0.8 iou with GT bbox 1 & 2,
  195. # bbox 1 will be assigned as the best target for bbox A in step 3.
  196. # However, if GT bbox 2's gt_argmax_overlaps = A, bbox A's
  197. # assigned_gt_inds will be overwritten to be bbox 2.
  198. # This might be the reason that it is not used in ROI Heads.
  199. for i in range(num_gts):
  200. if gt_max_overlaps[i] >= self.min_pos_iou:
  201. if self.gt_max_assign_all:
  202. max_iou_inds = overlaps[i, :] == gt_max_overlaps[i]
  203. assigned_gt_inds[max_iou_inds] = i + 1
  204. else:
  205. assigned_gt_inds[gt_argmax_overlaps[i]] = i + 1
  206. assigned_labels = assigned_gt_inds.new_full((num_bboxes, ), -1)
  207. pos_inds = torch.nonzero(
  208. assigned_gt_inds > 0, as_tuple=False).squeeze()
  209. if pos_inds.numel() > 0:
  210. assigned_labels[pos_inds] = gt_labels[assigned_gt_inds[pos_inds] -
  211. 1]
  212. return AssignResult(
  213. num_gts=num_gts,
  214. gt_inds=assigned_gt_inds,
  215. max_overlaps=max_overlaps,
  216. labels=assigned_labels)