grid_assigner.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import Optional, Tuple, Union
  3. import torch
  4. from mmengine.structures import InstanceData
  5. from mmdet.registry import TASK_UTILS
  6. from mmdet.utils import ConfigType
  7. from .assign_result import AssignResult
  8. from .base_assigner import BaseAssigner
  9. @TASK_UTILS.register_module()
  10. class GridAssigner(BaseAssigner):
  11. """Assign a corresponding gt bbox or background to each bbox.
  12. Each proposals will be assigned with `-1`, `0`, or a positive integer
  13. indicating the ground truth index.
  14. - -1: don't care
  15. - 0: negative sample, no assigned gt
  16. - positive integer: positive sample, index (1-based) of assigned gt
  17. Args:
  18. pos_iou_thr (float): IoU threshold for positive bboxes.
  19. neg_iou_thr (float or tuple[float, float]): IoU threshold for negative
  20. bboxes.
  21. min_pos_iou (float): Minimum iou for a bbox to be considered as a
  22. positive bbox. Positive samples can have smaller IoU than
  23. pos_iou_thr due to the 4th step (assign max IoU sample to each gt).
  24. Defaults to 0.
  25. gt_max_assign_all (bool): Whether to assign all bboxes with the same
  26. highest overlap with some gt to that gt.
  27. iou_calculator (:obj:`ConfigDict` or dict): Config of overlaps
  28. Calculator.
  29. """
  30. def __init__(
  31. self,
  32. pos_iou_thr: float,
  33. neg_iou_thr: Union[float, Tuple[float, float]],
  34. min_pos_iou: float = .0,
  35. gt_max_assign_all: bool = True,
  36. iou_calculator: ConfigType = dict(type='BboxOverlaps2D')
  37. ) -> None:
  38. self.pos_iou_thr = pos_iou_thr
  39. self.neg_iou_thr = neg_iou_thr
  40. self.min_pos_iou = min_pos_iou
  41. self.gt_max_assign_all = gt_max_assign_all
  42. self.iou_calculator = TASK_UTILS.build(iou_calculator)
  43. def assign(self,
  44. pred_instances: InstanceData,
  45. gt_instances: InstanceData,
  46. gt_instances_ignore: Optional[InstanceData] = None,
  47. **kwargs) -> AssignResult:
  48. """Assign gt to bboxes. The process is very much like the max iou
  49. assigner, except that positive samples are constrained within the cell
  50. that the gt boxes fell in.
  51. This method assign a gt bbox to every bbox (proposal/anchor), each bbox
  52. will be assigned with -1, 0, or a positive number. -1 means don't care,
  53. 0 means negative sample, positive number is the index (1-based) of
  54. assigned gt.
  55. The assignment is done in following steps, the order matters.
  56. 1. assign every bbox to -1
  57. 2. assign proposals whose iou with all gts <= neg_iou_thr to 0
  58. 3. for each bbox within a cell, if the iou with its nearest gt >
  59. pos_iou_thr and the center of that gt falls inside the cell,
  60. assign it to that bbox
  61. 4. for each gt bbox, assign its nearest proposals within the cell the
  62. gt bbox falls in to itself.
  63. Args:
  64. pred_instances (:obj:`InstanceData`): Instances of model
  65. predictions. It includes ``priors``, and the priors can
  66. be anchors or points, or the bboxes predicted by the
  67. previous stage, has shape (n, 4). The bboxes predicted by
  68. the current model or stage will be named ``bboxes``,
  69. ``labels``, and ``scores``, the same as the ``InstanceData``
  70. in other places.
  71. gt_instances (:obj:`InstanceData`): Ground truth of instance
  72. annotations. It usually includes ``bboxes``, with shape (k, 4),
  73. and ``labels``, with shape (k, ).
  74. gt_instances_ignore (:obj:`InstanceData`, optional): Instances
  75. to be ignored during training. It includes ``bboxes``
  76. attribute data that is ignored during training and testing.
  77. Defaults to None.
  78. Returns:
  79. :obj:`AssignResult`: The assign result.
  80. """
  81. gt_bboxes = gt_instances.bboxes
  82. gt_labels = gt_instances.labels
  83. priors = pred_instances.priors
  84. responsible_flags = pred_instances.responsible_flags
  85. num_gts, num_priors = gt_bboxes.size(0), priors.size(0)
  86. # compute iou between all gt and priors
  87. overlaps = self.iou_calculator(gt_bboxes, priors)
  88. # 1. assign -1 by default
  89. assigned_gt_inds = overlaps.new_full((num_priors, ),
  90. -1,
  91. dtype=torch.long)
  92. if num_gts == 0 or num_priors == 0:
  93. # No ground truth or priors, return empty assignment
  94. max_overlaps = overlaps.new_zeros((num_priors, ))
  95. if num_gts == 0:
  96. # No truth, assign everything to background
  97. assigned_gt_inds[:] = 0
  98. assigned_labels = overlaps.new_full((num_priors, ),
  99. -1,
  100. dtype=torch.long)
  101. return AssignResult(
  102. num_gts,
  103. assigned_gt_inds,
  104. max_overlaps,
  105. labels=assigned_labels)
  106. # 2. assign negative: below
  107. # for each anchor, which gt best overlaps with it
  108. # for each anchor, the max iou of all gts
  109. # shape of max_overlaps == argmax_overlaps == num_priors
  110. max_overlaps, argmax_overlaps = overlaps.max(dim=0)
  111. if isinstance(self.neg_iou_thr, float):
  112. assigned_gt_inds[(max_overlaps >= 0)
  113. & (max_overlaps <= self.neg_iou_thr)] = 0
  114. elif isinstance(self.neg_iou_thr, (tuple, list)):
  115. assert len(self.neg_iou_thr) == 2
  116. assigned_gt_inds[(max_overlaps > self.neg_iou_thr[0])
  117. & (max_overlaps <= self.neg_iou_thr[1])] = 0
  118. # 3. assign positive: falls into responsible cell and above
  119. # positive IOU threshold, the order matters.
  120. # the prior condition of comparison is to filter out all
  121. # unrelated anchors, i.e. not responsible_flags
  122. overlaps[:, ~responsible_flags.type(torch.bool)] = -1.
  123. # calculate max_overlaps again, but this time we only consider IOUs
  124. # for anchors responsible for prediction
  125. max_overlaps, argmax_overlaps = overlaps.max(dim=0)
  126. # for each gt, which anchor best overlaps with it
  127. # for each gt, the max iou of all proposals
  128. # shape of gt_max_overlaps == gt_argmax_overlaps == num_gts
  129. gt_max_overlaps, gt_argmax_overlaps = overlaps.max(dim=1)
  130. pos_inds = (max_overlaps > self.pos_iou_thr) & responsible_flags.type(
  131. torch.bool)
  132. assigned_gt_inds[pos_inds] = argmax_overlaps[pos_inds] + 1
  133. # 4. assign positive to max overlapped anchors within responsible cell
  134. for i in range(num_gts):
  135. if gt_max_overlaps[i] > self.min_pos_iou:
  136. if self.gt_max_assign_all:
  137. max_iou_inds = (overlaps[i, :] == gt_max_overlaps[i]) & \
  138. responsible_flags.type(torch.bool)
  139. assigned_gt_inds[max_iou_inds] = i + 1
  140. elif responsible_flags[gt_argmax_overlaps[i]]:
  141. assigned_gt_inds[gt_argmax_overlaps[i]] = i + 1
  142. # assign labels of positive anchors
  143. assigned_labels = assigned_gt_inds.new_full((num_priors, ), -1)
  144. pos_inds = torch.nonzero(
  145. assigned_gt_inds > 0, as_tuple=False).squeeze()
  146. if pos_inds.numel() > 0:
  147. assigned_labels[pos_inds] = gt_labels[assigned_gt_inds[pos_inds] -
  148. 1]
  149. return AssignResult(
  150. num_gts, assigned_gt_inds, max_overlaps, labels=assigned_labels)