sim_ota_assigner.py 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import Optional, Tuple
  3. import torch
  4. import torch.nn.functional as F
  5. from mmengine.structures import InstanceData
  6. from torch import Tensor
  7. from mmdet.registry import TASK_UTILS
  8. from mmdet.utils import ConfigType
  9. from .assign_result import AssignResult
  10. from .base_assigner import BaseAssigner
  11. INF = 100000.0
  12. EPS = 1.0e-7
  13. @TASK_UTILS.register_module()
  14. class SimOTAAssigner(BaseAssigner):
  15. """Computes matching between predictions and ground truth.
  16. Args:
  17. center_radius (float): Ground truth center size
  18. to judge whether a prior is in center. Defaults to 2.5.
  19. candidate_topk (int): The candidate top-k which used to
  20. get top-k ious to calculate dynamic-k. Defaults to 10.
  21. iou_weight (float): The scale factor for regression
  22. iou cost. Defaults to 3.0.
  23. cls_weight (float): The scale factor for classification
  24. cost. Defaults to 1.0.
  25. iou_calculator (ConfigType): Config of overlaps Calculator.
  26. Defaults to dict(type='BboxOverlaps2D').
  27. """
  28. def __init__(self,
  29. center_radius: float = 2.5,
  30. candidate_topk: int = 10,
  31. iou_weight: float = 3.0,
  32. cls_weight: float = 1.0,
  33. iou_calculator: ConfigType = dict(type='BboxOverlaps2D')):
  34. self.center_radius = center_radius
  35. self.candidate_topk = candidate_topk
  36. self.iou_weight = iou_weight
  37. self.cls_weight = cls_weight
  38. self.iou_calculator = TASK_UTILS.build(iou_calculator)
  39. def assign(self,
  40. pred_instances: InstanceData,
  41. gt_instances: InstanceData,
  42. gt_instances_ignore: Optional[InstanceData] = None,
  43. **kwargs) -> AssignResult:
  44. """Assign gt to priors using SimOTA.
  45. Args:
  46. pred_instances (:obj:`InstanceData`): Instances of model
  47. predictions. It includes ``priors``, and the priors can
  48. be anchors or points, or the bboxes predicted by the
  49. previous stage, has shape (n, 4). The bboxes predicted by
  50. the current model or stage will be named ``bboxes``,
  51. ``labels``, and ``scores``, the same as the ``InstanceData``
  52. in other places.
  53. gt_instances (:obj:`InstanceData`): Ground truth of instance
  54. annotations. It usually includes ``bboxes``, with shape (k, 4),
  55. and ``labels``, with shape (k, ).
  56. gt_instances_ignore (:obj:`InstanceData`, optional): Instances
  57. to be ignored during training. It includes ``bboxes``
  58. attribute data that is ignored during training and testing.
  59. Defaults to None.
  60. Returns:
  61. obj:`AssignResult`: The assigned result.
  62. """
  63. gt_bboxes = gt_instances.bboxes
  64. gt_labels = gt_instances.labels
  65. num_gt = gt_bboxes.size(0)
  66. decoded_bboxes = pred_instances.bboxes
  67. pred_scores = pred_instances.scores
  68. priors = pred_instances.priors
  69. num_bboxes = decoded_bboxes.size(0)
  70. # assign 0 by default
  71. assigned_gt_inds = decoded_bboxes.new_full((num_bboxes, ),
  72. 0,
  73. dtype=torch.long)
  74. if num_gt == 0 or num_bboxes == 0:
  75. # No ground truth or boxes, return empty assignment
  76. max_overlaps = decoded_bboxes.new_zeros((num_bboxes, ))
  77. assigned_labels = decoded_bboxes.new_full((num_bboxes, ),
  78. -1,
  79. dtype=torch.long)
  80. return AssignResult(
  81. num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels)
  82. valid_mask, is_in_boxes_and_center = self.get_in_gt_and_in_center_info(
  83. priors, gt_bboxes)
  84. valid_decoded_bbox = decoded_bboxes[valid_mask]
  85. valid_pred_scores = pred_scores[valid_mask]
  86. num_valid = valid_decoded_bbox.size(0)
  87. if num_valid == 0:
  88. # No valid bboxes, return empty assignment
  89. max_overlaps = decoded_bboxes.new_zeros((num_bboxes, ))
  90. assigned_labels = decoded_bboxes.new_full((num_bboxes, ),
  91. -1,
  92. dtype=torch.long)
  93. return AssignResult(
  94. num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels)
  95. pairwise_ious = self.iou_calculator(valid_decoded_bbox, gt_bboxes)
  96. iou_cost = -torch.log(pairwise_ious + EPS)
  97. gt_onehot_label = (
  98. F.one_hot(gt_labels.to(torch.int64),
  99. pred_scores.shape[-1]).float().unsqueeze(0).repeat(
  100. num_valid, 1, 1))
  101. valid_pred_scores = valid_pred_scores.unsqueeze(1).repeat(1, num_gt, 1)
  102. # disable AMP autocast and calculate BCE with FP32 to avoid overflow
  103. with torch.cuda.amp.autocast(enabled=False):
  104. cls_cost = (
  105. F.binary_cross_entropy(
  106. valid_pred_scores.to(dtype=torch.float32),
  107. gt_onehot_label,
  108. reduction='none',
  109. ).sum(-1).to(dtype=valid_pred_scores.dtype))
  110. cost_matrix = (
  111. cls_cost * self.cls_weight + iou_cost * self.iou_weight +
  112. (~is_in_boxes_and_center) * INF)
  113. matched_pred_ious, matched_gt_inds = \
  114. self.dynamic_k_matching(
  115. cost_matrix, pairwise_ious, num_gt, valid_mask)
  116. # convert to AssignResult format
  117. assigned_gt_inds[valid_mask] = matched_gt_inds + 1
  118. assigned_labels = assigned_gt_inds.new_full((num_bboxes, ), -1)
  119. assigned_labels[valid_mask] = gt_labels[matched_gt_inds].long()
  120. max_overlaps = assigned_gt_inds.new_full((num_bboxes, ),
  121. -INF,
  122. dtype=torch.float32)
  123. max_overlaps[valid_mask] = matched_pred_ious
  124. return AssignResult(
  125. num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels)
  126. def get_in_gt_and_in_center_info(
  127. self, priors: Tensor, gt_bboxes: Tensor) -> Tuple[Tensor, Tensor]:
  128. """Get the information of which prior is in gt bboxes and gt center
  129. priors."""
  130. num_gt = gt_bboxes.size(0)
  131. repeated_x = priors[:, 0].unsqueeze(1).repeat(1, num_gt)
  132. repeated_y = priors[:, 1].unsqueeze(1).repeat(1, num_gt)
  133. repeated_stride_x = priors[:, 2].unsqueeze(1).repeat(1, num_gt)
  134. repeated_stride_y = priors[:, 3].unsqueeze(1).repeat(1, num_gt)
  135. # is prior centers in gt bboxes, shape: [n_prior, n_gt]
  136. l_ = repeated_x - gt_bboxes[:, 0]
  137. t_ = repeated_y - gt_bboxes[:, 1]
  138. r_ = gt_bboxes[:, 2] - repeated_x
  139. b_ = gt_bboxes[:, 3] - repeated_y
  140. deltas = torch.stack([l_, t_, r_, b_], dim=1)
  141. is_in_gts = deltas.min(dim=1).values > 0
  142. is_in_gts_all = is_in_gts.sum(dim=1) > 0
  143. # is prior centers in gt centers
  144. gt_cxs = (gt_bboxes[:, 0] + gt_bboxes[:, 2]) / 2.0
  145. gt_cys = (gt_bboxes[:, 1] + gt_bboxes[:, 3]) / 2.0
  146. ct_box_l = gt_cxs - self.center_radius * repeated_stride_x
  147. ct_box_t = gt_cys - self.center_radius * repeated_stride_y
  148. ct_box_r = gt_cxs + self.center_radius * repeated_stride_x
  149. ct_box_b = gt_cys + self.center_radius * repeated_stride_y
  150. cl_ = repeated_x - ct_box_l
  151. ct_ = repeated_y - ct_box_t
  152. cr_ = ct_box_r - repeated_x
  153. cb_ = ct_box_b - repeated_y
  154. ct_deltas = torch.stack([cl_, ct_, cr_, cb_], dim=1)
  155. is_in_cts = ct_deltas.min(dim=1).values > 0
  156. is_in_cts_all = is_in_cts.sum(dim=1) > 0
  157. # in boxes or in centers, shape: [num_priors]
  158. is_in_gts_or_centers = is_in_gts_all | is_in_cts_all
  159. # both in boxes and centers, shape: [num_fg, num_gt]
  160. is_in_boxes_and_centers = (
  161. is_in_gts[is_in_gts_or_centers, :]
  162. & is_in_cts[is_in_gts_or_centers, :])
  163. return is_in_gts_or_centers, is_in_boxes_and_centers
  164. def dynamic_k_matching(self, cost: Tensor, pairwise_ious: Tensor,
  165. num_gt: int,
  166. valid_mask: Tensor) -> Tuple[Tensor, Tensor]:
  167. """Use IoU and matching cost to calculate the dynamic top-k positive
  168. targets."""
  169. matching_matrix = torch.zeros_like(cost, dtype=torch.uint8)
  170. # select candidate topk ious for dynamic-k calculation
  171. candidate_topk = min(self.candidate_topk, pairwise_ious.size(0))
  172. topk_ious, _ = torch.topk(pairwise_ious, candidate_topk, dim=0)
  173. # calculate dynamic k for each gt
  174. dynamic_ks = torch.clamp(topk_ious.sum(0).int(), min=1)
  175. for gt_idx in range(num_gt):
  176. _, pos_idx = torch.topk(
  177. cost[:, gt_idx], k=dynamic_ks[gt_idx], largest=False)
  178. matching_matrix[:, gt_idx][pos_idx] = 1
  179. del topk_ious, dynamic_ks, pos_idx
  180. prior_match_gt_mask = matching_matrix.sum(1) > 1
  181. if prior_match_gt_mask.sum() > 0:
  182. cost_min, cost_argmin = torch.min(
  183. cost[prior_match_gt_mask, :], dim=1)
  184. matching_matrix[prior_match_gt_mask, :] *= 0
  185. matching_matrix[prior_match_gt_mask, cost_argmin] = 1
  186. # get foreground mask inside box and center prior
  187. fg_mask_inboxes = matching_matrix.sum(1) > 0
  188. valid_mask[valid_mask.clone()] = fg_mask_inboxes
  189. matched_gt_inds = matching_matrix[fg_mask_inboxes, :].argmax(1)
  190. matched_pred_ious = (matching_matrix *
  191. pairwise_ious).sum(1)[fg_mask_inboxes]
  192. return matched_pred_ious, matched_gt_inds