multi_instance_assigner.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import Optional
  3. import torch
  4. from mmengine.structures import InstanceData
  5. from mmdet.registry import TASK_UTILS
  6. from .assign_result import AssignResult
  7. from .max_iou_assigner import MaxIoUAssigner
  8. @TASK_UTILS.register_module()
  9. class MultiInstanceAssigner(MaxIoUAssigner):
  10. """Assign a corresponding gt bbox or background to each proposal bbox. If
  11. we need to use a proposal box to generate multiple predict boxes,
  12. `MultiInstanceAssigner` can assign multiple gt to each proposal box.
  13. Args:
  14. num_instance (int): How many bboxes are predicted by each proposal box.
  15. """
  16. def __init__(self, num_instance: int = 2, **kwargs):
  17. super().__init__(**kwargs)
  18. self.num_instance = num_instance
  19. def assign(self,
  20. pred_instances: InstanceData,
  21. gt_instances: InstanceData,
  22. gt_instances_ignore: Optional[InstanceData] = None,
  23. **kwargs) -> AssignResult:
  24. """Assign gt to bboxes.
  25. This method assign gt bboxes to every bbox (proposal/anchor), each bbox
  26. is assigned a set of gts, and the number of gts in this set is defined
  27. by `self.num_instance`.
  28. Args:
  29. pred_instances (:obj:`InstanceData`): Instances of model
  30. predictions. It includes ``priors``, and the priors can
  31. be anchors or points, or the bboxes predicted by the
  32. previous stage, has shape (n, 4). The bboxes predicted by
  33. the current model or stage will be named ``bboxes``,
  34. ``labels``, and ``scores``, the same as the ``InstanceData``
  35. in other places.
  36. gt_instances (:obj:`InstanceData`): Ground truth of instance
  37. annotations. It usually includes ``bboxes``, with shape (k, 4),
  38. and ``labels``, with shape (k, ).
  39. gt_instances_ignore (:obj:`InstanceData`, optional): Instances
  40. to be ignored during training. It includes ``bboxes``
  41. attribute data that is ignored during training and testing.
  42. Defaults to None.
  43. Returns:
  44. :obj:`AssignResult`: The assign result.
  45. """
  46. gt_bboxes = gt_instances.bboxes
  47. priors = pred_instances.priors
  48. # Set the FG label to 1 and add ignored annotations
  49. gt_labels = gt_instances.labels + 1
  50. if gt_instances_ignore is not None:
  51. gt_bboxes_ignore = gt_instances_ignore.bboxes
  52. if hasattr(gt_instances_ignore, 'labels'):
  53. gt_labels_ignore = gt_instances_ignore.labels
  54. else:
  55. gt_labels_ignore = torch.ones_like(gt_bboxes_ignore)[:, 0] * -1
  56. else:
  57. gt_bboxes_ignore = None
  58. gt_labels_ignore = None
  59. assign_on_cpu = True if (self.gpu_assign_thr > 0) and (
  60. gt_bboxes.shape[0] > self.gpu_assign_thr) else False
  61. # compute overlap and assign gt on CPU when number of GT is large
  62. if assign_on_cpu:
  63. device = priors.device
  64. priors = priors.cpu()
  65. gt_bboxes = gt_bboxes.cpu()
  66. gt_labels = gt_labels.cpu()
  67. if gt_bboxes_ignore is not None:
  68. gt_bboxes_ignore = gt_bboxes_ignore.cpu()
  69. gt_labels_ignore = gt_labels_ignore.cpu()
  70. if gt_bboxes_ignore is not None:
  71. all_bboxes = torch.cat([gt_bboxes, gt_bboxes_ignore], dim=0)
  72. all_labels = torch.cat([gt_labels, gt_labels_ignore], dim=0)
  73. else:
  74. all_bboxes = gt_bboxes
  75. all_labels = gt_labels
  76. all_priors = torch.cat([priors, all_bboxes], dim=0)
  77. overlaps_normal = self.iou_calculator(
  78. all_priors, all_bboxes, mode='iou')
  79. overlaps_ignore = self.iou_calculator(
  80. all_priors, all_bboxes, mode='iof')
  81. gt_ignore_mask = all_labels.eq(-1).repeat(all_priors.shape[0], 1)
  82. overlaps_normal = overlaps_normal * ~gt_ignore_mask
  83. overlaps_ignore = overlaps_ignore * gt_ignore_mask
  84. overlaps_normal, overlaps_normal_indices = overlaps_normal.sort(
  85. descending=True, dim=1)
  86. overlaps_ignore, overlaps_ignore_indices = overlaps_ignore.sort(
  87. descending=True, dim=1)
  88. # select the roi with the higher score
  89. max_overlaps_normal = overlaps_normal[:, :self.num_instance].flatten()
  90. gt_assignment_normal = overlaps_normal_indices[:, :self.
  91. num_instance].flatten()
  92. max_overlaps_ignore = overlaps_ignore[:, :self.num_instance].flatten()
  93. gt_assignment_ignore = overlaps_ignore_indices[:, :self.
  94. num_instance].flatten()
  95. # ignore or not
  96. ignore_assign_mask = (max_overlaps_normal < self.pos_iou_thr) * (
  97. max_overlaps_ignore > max_overlaps_normal)
  98. overlaps = (max_overlaps_normal * ~ignore_assign_mask) + (
  99. max_overlaps_ignore * ignore_assign_mask)
  100. gt_assignment = (gt_assignment_normal * ~ignore_assign_mask) + (
  101. gt_assignment_ignore * ignore_assign_mask)
  102. assigned_labels = all_labels[gt_assignment]
  103. fg_mask = (overlaps >= self.pos_iou_thr) * (assigned_labels != -1)
  104. bg_mask = (overlaps < self.neg_iou_thr) * (overlaps >= 0)
  105. assigned_labels[fg_mask] = 1
  106. assigned_labels[bg_mask] = 0
  107. overlaps = overlaps.reshape(-1, self.num_instance)
  108. gt_assignment = gt_assignment.reshape(-1, self.num_instance)
  109. assigned_labels = assigned_labels.reshape(-1, self.num_instance)
  110. assign_result = AssignResult(
  111. num_gts=all_bboxes.size(0),
  112. gt_inds=gt_assignment,
  113. max_overlaps=overlaps,
  114. labels=assigned_labels)
  115. if assign_on_cpu:
  116. assign_result.gt_inds = assign_result.gt_inds.to(device)
  117. assign_result.max_overlaps = assign_result.max_overlaps.to(device)
  118. if assign_result.labels is not None:
  119. assign_result.labels = assign_result.labels.to(device)
  120. return assign_result