uniform_assigner.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import Optional
  3. import torch
  4. from mmengine.structures import InstanceData
  5. from mmdet.registry import TASK_UTILS
  6. from mmdet.structures.bbox import bbox_xyxy_to_cxcywh
  7. from mmdet.utils import ConfigType
  8. from .assign_result import AssignResult
  9. from .base_assigner import BaseAssigner
  @TASK_UTILS.register_module()
  class UniformAssigner(BaseAssigner):
      """Uniform matching between the priors and gt boxes.

      Every gt box is matched with the same number of priors, which balances
      the positive priors across gt boxes. ``gt_instances_ignore`` is not
      considered for now.

      Args:
          pos_ignore_thr (float): Matched priors whose IoU with their gt box
              is below this threshold are ignored (assigned ``-1``).
          neg_ignore_thr (float): Unmatched priors whose predicted box has a
              maximum IoU with the gt boxes above this threshold are ignored
              (assigned ``-1``).
          match_times (int): Number of positive priors for each gt box,
              selected once by predicted boxes and once by priors.
              Defaults to 4.
          iou_calculator (:obj:`ConfigDict` or dict): Config dict for iou
              calculator. Defaults to ``dict(type='BboxOverlaps2D')``
      """

      def __init__(self,
                   pos_ignore_thr: float,
                   neg_ignore_thr: float,
                   match_times: int = 4,
                   iou_calculator: ConfigType = dict(type='BboxOverlaps2D')):
          self.match_times = match_times
          self.pos_ignore_thr = pos_ignore_thr
          self.neg_ignore_thr = neg_ignore_thr
          self.iou_calculator = TASK_UTILS.build(iou_calculator)

      def assign(
          self,
          pred_instances: InstanceData,
          gt_instances: InstanceData,
          gt_instances_ignore: Optional[InstanceData] = None
      ) -> AssignResult:
          """Assign gt to priors.

          The assignment is done in following steps

          1. assign every prior to background (gt index 0, label -1) by
             default
          2. compute the L1 cost between boxes. Note that we use priors and
             predict boxes both
          3. compute the ignore indexes use gt_bboxes and predict boxes
          4. compute the ignore indexes of positive sample use priors and
             predict boxes

          Args:
              pred_instances (:obj:`InstanceData`): Instances of model
                  predictions. ``priors`` and ``decoder_priors`` (the
                  predicted boxes) are read from it.
              gt_instances (:obj:`InstanceData`): Ground truth of instance
                  annotations. ``bboxes`` and ``labels`` are read from it.
              gt_instances_ignore (:obj:`InstanceData`, optional): Instances
                  to be ignored during training. Currently unused by this
                  assigner. Defaults to None.

          Returns:
              :obj:`AssignResult`: The assign result. It additionally
              carries the extra properties ``pos_idx`` (mask of matched
              pairs that were not ignored), ``pos_predicted_boxes`` and
              ``target_boxes`` (one row per matched pair).
          """
          gt_bboxes = gt_instances.bboxes
          gt_labels = gt_instances.labels
          priors = pred_instances.priors
          bbox_pred = pred_instances.decoder_priors
          num_gts, num_bboxes = gt_bboxes.size(0), bbox_pred.size(0)

          # 1. Every prior starts as background: gt index 0 and label -1.
          assigned_gt_inds = bbox_pred.new_full((num_bboxes, ),
                                                0,
                                                dtype=torch.long)
          assigned_labels = bbox_pred.new_full((num_bboxes, ),
                                               -1,
                                               dtype=torch.long)

          if num_gts == 0 or num_bboxes == 0:
              # No ground truth or no boxes: nothing can be matched, so the
              # all-background defaults are returned with empty extras.
              assign_result = AssignResult(
                  num_gts, assigned_gt_inds, None, labels=assigned_labels)
              assign_result.set_extra_property(
                  'pos_idx', bbox_pred.new_empty(0, dtype=torch.bool))
              assign_result.set_extra_property('pos_predicted_boxes',
                                               bbox_pred.new_empty((0, 4)))
              assign_result.set_extra_property('target_boxes',
                                               bbox_pred.new_empty((0, 4)))
              return assign_result

          # 2. Compute the L1 cost between boxes
          # Note that we use priors and predict boxes both
          gt_cxcywh = bbox_xyxy_to_cxcywh(gt_bboxes)
          cost_bbox = torch.cdist(
              bbox_xyxy_to_cxcywh(bbox_pred), gt_cxcywh, p=1)
          cost_bbox_priors = torch.cdist(
              bbox_xyxy_to_cxcywh(priors), gt_cxcywh, p=1)

          # We found that topk function has different results in cpu and
          # cuda mode. In order to ensure consistency with the source code,
          # we also use cpu mode.
          # TODO: Check whether the performance of cpu and cuda are the same.
          C = cost_bbox.cpu()
          C1 = cost_bbox_priors.cpu()

          # topk raises when k exceeds the size of the reduced dimension, so
          # never ask for more matches per gt than there are candidate rows.
          num_matches = min(self.match_times, C.size(0), C1.size(0))

          # num_matches x num_gts: lowest-cost predicted boxes for each gt
          index = torch.topk(C, k=num_matches, dim=0, largest=False)[1]
          # num_matches x num_gts: lowest-cost priors for each gt
          index1 = torch.topk(C1, k=num_matches, dim=0, largest=False)[1]
          # Flattened to (num_matches * 2 * num_gts, ); each consecutive
          # chunk of num_gts entries holds one candidate per gt, in gt order.
          indexes = torch.cat((index, index1),
                              dim=1).reshape(-1).to(bbox_pred.device)

          pred_overlaps = self.iou_calculator(bbox_pred, gt_bboxes)
          anchor_overlaps = self.iou_calculator(priors, gt_bboxes)
          pred_max_overlaps, _ = pred_overlaps.max(dim=1)
          # NOTE(review): this reduces over dim 0, the axis that is indexed
          # by prior below, so it yields one value per gt rather than one per
          # prior -- confirm consumers of ``max_overlaps`` expect that.
          anchor_max_overlaps, _ = anchor_overlaps.max(dim=0)

          # 3. Compute the ignore indexes use gt_bboxes and predict boxes
          ignore_idx = pred_max_overlaps > self.neg_ignore_thr
          assigned_gt_inds[ignore_idx] = -1

          # 4. Compute the ignore indexes of positive sample use priors
          # and predict boxes
          # gt index of every entry in ``indexes`` (same chunked layout).
          pos_gt_index = torch.arange(
              0, C1.size(1),
              device=bbox_pred.device).repeat(num_matches * 2)
          pos_ious = anchor_overlaps[indexes, pos_gt_index]
          pos_ignore_idx = pos_ious < self.pos_ignore_thr

          # gt indices are stored 1-based (0 is background); matched pairs
          # with too small an IoU are marked as ignored instead.
          pos_gt_index_with_ignore = pos_gt_index + 1
          pos_gt_index_with_ignore[pos_ignore_idx] = -1
          assigned_gt_inds[indexes] = pos_gt_index_with_ignore

          if gt_labels is not None:
              assigned_labels = assigned_gt_inds.new_full((num_bboxes, ), -1)
              # squeeze only the column axis so a single positive stays 1-D
              pos_inds = torch.nonzero(
                  assigned_gt_inds > 0, as_tuple=False).squeeze(1)
              if pos_inds.numel() > 0:
                  assigned_labels[pos_inds] = gt_labels[
                      assigned_gt_inds[pos_inds] - 1]
          else:
              assigned_labels = None

          assign_result = AssignResult(
              num_gts,
              assigned_gt_inds,
              anchor_max_overlaps,
              labels=assigned_labels)
          assign_result.set_extra_property('pos_idx', ~pos_ignore_idx)
          assign_result.set_extra_property('pos_predicted_boxes',
                                           bbox_pred[indexes])
          assign_result.set_extra_property('target_boxes',
                                           gt_bboxes[pos_gt_index])
          return assign_result