task_aligned_assigner.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import Optional
  3. import torch
  4. from mmengine.structures import InstanceData
  5. from mmdet.registry import TASK_UTILS
  6. from mmdet.utils import ConfigType
  7. from .assign_result import AssignResult
  8. from .base_assigner import BaseAssigner
  # Large finite sentinel. ``-INF`` is written into the overlap entries of
  # (prior, gt) pairs that are not positive candidates, so they never win the
  # per-prior max and can be recognised afterwards via ``!= -INF``.
  9. INF = 100000000
  @TASK_UTILS.register_module()
  class TaskAlignedAssigner(BaseAssigner):
      """Task aligned assigner used in the paper:
      `TOOD: Task-aligned One-stage Object Detection.
      <https://arxiv.org/abs/2108.07755>`_.

      Assign a corresponding gt bbox or background to each predicted bbox.
      Each bbox will be assigned with `0` or a positive integer
      indicating the ground truth index.

      - 0: negative sample, no assigned gt
      - positive integer: positive sample, index (1-based) of assigned gt

      Args:
          topk (int): Number of bboxes with the highest alignment metric
              that are kept as candidates for each gt. The selection runs
              over all priors at once. Must be at least 1.
          iou_calculator (:obj:`ConfigDict` or dict): Config dict for iou
              calculator. Defaults to ``dict(type='BboxOverlaps2D')``
      """

      def __init__(self,
                   topk: int,
                   iou_calculator: ConfigType = dict(type='BboxOverlaps2D')):
          assert topk >= 1
          self.topk = topk
          self.iou_calculator = TASK_UTILS.build(iou_calculator)

      def assign(self,
                 pred_instances: InstanceData,
                 gt_instances: InstanceData,
                 gt_instances_ignore: Optional[InstanceData] = None,
                 alpha: int = 1,
                 beta: int = 6) -> AssignResult:
          """Assign gt to bboxes.

          The assignment is done in following steps

          1. compute alignment metric between all bbox (bbox of all pyramid
             levels) and gt
          2. select top-k bbox as candidates for each gt
          3. limit the positive sample's center in gt (because the anchor-free
             detector only can predict positive distance)
          4. if a bbox is a candidate of several gts, keep the gt with the
             highest overlap

          Args:
              pred_instances (:obj:`InstanceData`): Instances of model
                  predictions. It must provide ``priors`` (anchors, points
                  or bboxes; only the first 4 columns are used, read as
                  x1, y1, x2, y2 when computing centers), ``bboxes`` (the
                  boxes passed to the iou calculator) and ``scores``
                  (indexed as ``scores[:, gt_labels]``).
              gt_instances (:obj:`InstanceData`): Ground truth of instance
                  annotations. ``bboxes`` and ``labels`` are read.
              gt_instances_ignore (:obj:`InstanceData`, optional): Instances
                  to be ignored during training. Accepted for interface
                  compatibility; this method does not read it.
                  Defaults to None.
              alpha (int): Exponent applied to the classification score in
                  the alignment metric. Defaults to 1.
              beta (int): Exponent applied to the overlap in the alignment
                  metric. Defaults to 6.

          Returns:
              :obj:`AssignResult`: The assign result. ``max_overlaps`` holds
              ``-INF`` for priors that are not a positive candidate of any
              gt. An extra attribute ``assign_metrics`` carries the alignment
              metric of the assigned gt for positives and 0 elsewhere.
          """
          priors = pred_instances.priors[:, :4]
          decode_bboxes = pred_instances.bboxes
          pred_scores = pred_instances.scores
          gt_bboxes = gt_instances.bboxes
          gt_labels = gt_instances.labels
          num_gt, num_bboxes = gt_bboxes.size(0), priors.size(0)

          # compute alignment metric between all bbox and gt
          overlaps = self.iou_calculator(decode_bboxes, gt_bboxes).detach()
          bbox_scores = pred_scores[:, gt_labels].detach()

          # assign 0 (background) by default
          assigned_gt_inds = priors.new_full((num_bboxes, ),
                                             0,
                                             dtype=torch.long)
          assign_metrics = priors.new_zeros((num_bboxes, ))

          if num_gt == 0 or num_bboxes == 0:
              # No ground truth or no boxes: everything stays background
              # (``assigned_gt_inds`` is already all zeros).
              max_overlaps = priors.new_zeros((num_bboxes, ))
              assigned_labels = priors.new_full((num_bboxes, ),
                                                -1,
                                                dtype=torch.long)
              assign_result = AssignResult(
                  num_gt,
                  assigned_gt_inds,
                  max_overlaps,
                  labels=assigned_labels)
              assign_result.assign_metrics = assign_metrics
              return assign_result

          # select top-k bboxes as candidates for each gt; both outputs have
          # shape (topk, num_gt) and the values are exactly
          # ``alignment_metrics[candidate_idxs[k, g], g]``
          alignment_metrics = bbox_scores**alpha * overlaps**beta
          topk = min(self.topk, alignment_metrics.size(0))
          candidate_metrics, candidate_idxs = alignment_metrics.topk(
              topk, dim=0, largest=True)
          is_pos = candidate_metrics > 0

          # limit the positive sample's center in gt: gather the candidate
          # centers directly, shape (topk, num_gt)
          priors_cx = (priors[:, 0] + priors[:, 2]) / 2.0
          priors_cy = (priors[:, 1] + priors[:, 3]) / 2.0
          cand_cx = priors_cx[candidate_idxs]
          cand_cy = priors_cy[candidate_idxs]

          # left, top, right, bottom distance between candidate center and
          # gt side; ``gt_bboxes[:, i]`` broadcasts over the topk dimension
          l_ = cand_cx - gt_bboxes[:, 0]
          t_ = cand_cy - gt_bboxes[:, 1]
          r_ = gt_bboxes[:, 2] - cand_cx
          b_ = gt_bboxes[:, 3] - cand_cy
          is_in_gts = torch.stack([l_, t_, r_, b_], dim=1).min(dim=1)[0] > 0.01
          is_pos = is_pos & is_in_gts

          # if an anchor box is assigned to multiple gts,
          # the one with the highest iou will be selected.
          # Work on a gt-major flattening of the (num_bboxes, num_gt) overlap
          # matrix: entry (prior i, gt g) lives at ``g * num_bboxes + i``.
          gt_offsets = torch.arange(
              num_gt, device=candidate_idxs.device) * num_bboxes
          flat_idxs = (candidate_idxs + gt_offsets).reshape(-1)
          index = flat_idxs[is_pos.reshape(-1)]

          flat_overlaps = overlaps.t().contiguous().view(-1)
          overlaps_inf = torch.full_like(flat_overlaps, -INF)
          overlaps_inf[index] = flat_overlaps[index]
          overlaps_inf = overlaps_inf.view(num_gt, -1).t()

          max_overlaps, argmax_overlaps = overlaps_inf.max(dim=1)
          # priors whose best entry is still the sentinel have no candidate gt
          pos_mask = max_overlaps != -INF
          matched_gt = argmax_overlaps[pos_mask]

          assigned_gt_inds[pos_mask] = matched_gt + 1
          assign_metrics[pos_mask] = alignment_metrics[pos_mask, matched_gt]

          # -1 for background, gt label for positives
          assigned_labels = assigned_gt_inds.new_full((num_bboxes, ), -1)
          assigned_labels[pos_mask] = gt_labels[matched_gt]

          assign_result = AssignResult(
              num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels)
          assign_result.assign_metrics = assign_metrics
          return assign_result