123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140 |
- # Copyright (c) OpenMMLab. All rights reserved.
- from typing import Optional
- import torch
- from mmengine.structures import InstanceData
- from mmdet.registry import TASK_UTILS
- from .assign_result import AssignResult
- from .max_iou_assigner import MaxIoUAssigner
@TASK_UTILS.register_module()
class MultiInstanceAssigner(MaxIoUAssigner):
    """Assign a corresponding gt bbox or background to each proposal bbox.

    If we need to use a proposal box to generate multiple predicted boxes,
    ``MultiInstanceAssigner`` assigns ``num_instance`` gts to each proposal
    box instead of a single one.

    Args:
        num_instance (int): How many bboxes are predicted by each proposal
            box. Defaults to 2.
        **kwargs: Forwarded unchanged to :class:`MaxIoUAssigner`.
    """

    def __init__(self, num_instance: int = 2, **kwargs):
        super().__init__(**kwargs)
        self.num_instance = num_instance

    def _top_k_matches(self, overlaps, indices):
        """Keep the first ``num_instance`` columns of row-sorted overlaps.

        When there are fewer columns (boxes) than ``num_instance``, the
        result is right-padded with zero overlap pointing at box 0, so that
        every row keeps exactly ``num_instance`` entries and the later
        ``reshape(-1, num_instance)`` groups rows correctly. Inputs with no
        columns at all are returned unpadded (original behaviour).

        Args:
            overlaps (Tensor): Overlaps sorted in descending order along
                dim 1, shape (num_rows, num_boxes).
            indices (Tensor): Column indices returned by that sort, same
                shape as ``overlaps``.

        Returns:
            tuple[Tensor, Tensor]: Flattened top overlaps and their box
            indices, each of length ``num_rows * num_instance`` (or 0 when
            there are no boxes).
        """
        top_overlaps = overlaps[:, :self.num_instance]
        top_indices = indices[:, :self.num_instance]
        num_missing = self.num_instance - top_overlaps.size(1)
        if num_missing > 0 and top_overlaps.size(1) > 0:
            num_rows = top_overlaps.size(0)
            # Zero overlap is <= every real IoU/IoF, so padded slots sort
            # last and look like a non-overlapping match to box 0.
            top_overlaps = torch.cat(
                [top_overlaps,
                 top_overlaps.new_zeros((num_rows, num_missing))],
                dim=1)
            top_indices = torch.cat(
                [top_indices,
                 top_indices.new_zeros((num_rows, num_missing))],
                dim=1)
        return top_overlaps.flatten(), top_indices.flatten()

    def assign(self,
               pred_instances: InstanceData,
               gt_instances: InstanceData,
               gt_instances_ignore: Optional[InstanceData] = None,
               **kwargs) -> AssignResult:
        """Assign gt to bboxes.

        This method assigns gt bboxes to every bbox (proposal/anchor). Each
        bbox is assigned a set of gts, and the number of gts in this set is
        defined by ``self.num_instance``. The gt (and ignored) boxes are
        appended to the priors, so they receive assignments as well.

        Only ``pos_iou_thr``, ``neg_iou_thr``, ``gpu_assign_thr`` and
        ``iou_calculator`` of the parent assigner are used here.

        Args:
            pred_instances (:obj:`InstanceData`): Instances of model
                predictions. It includes ``priors``, and the priors can
                be anchors or points, or the bboxes predicted by the
                previous stage, has shape (n, 4). The bboxes predicted by
                the current model or stage will be named ``bboxes``,
                ``labels``, and ``scores``, the same as the ``InstanceData``
                in other places.
            gt_instances (:obj:`InstanceData`): Ground truth of instance
                annotations. It usually includes ``bboxes``, with shape (k, 4),
                and ``labels``, with shape (k, ).
            gt_instances_ignore (:obj:`InstanceData`, optional): Instances
                to be ignored during training. It includes ``bboxes``
                attribute data that is ignored during training and testing.
                If it has no ``labels``, all of its boxes are labelled -1.
                Defaults to None.

        Returns:
            :obj:`AssignResult`: The assign result. ``gt_inds``,
            ``max_overlaps`` and ``labels`` have shape
            (n + num_boxes, num_instance), where ``num_boxes`` counts gt and
            ignored boxes. ``gt_inds`` are 0-based indices into the
            concatenation [gt bboxes, ignored bboxes]. ``labels`` are 1 for
            foreground and 0 for background; other entries keep the label of
            the assigned box (gt label + 1, or -1 for ignored boxes).
            ``num_gts`` also counts the ignored boxes.
        """
        gt_bboxes = gt_instances.bboxes
        priors = pred_instances.priors
        # Shift gt labels by +1 so that 0 stays free for background.
        gt_labels = gt_instances.labels + 1
        if gt_instances_ignore is not None:
            gt_bboxes_ignore = gt_instances_ignore.bboxes
            if hasattr(gt_instances_ignore, 'labels'):
                gt_labels_ignore = gt_instances_ignore.labels
            else:
                # Mark ignored boxes with -1. Use gt_labels' integer dtype so
                # the concatenation below does not promote labels to float.
                gt_labels_ignore = gt_labels.new_full(
                    (gt_bboxes_ignore.size(0), ), -1)
        else:
            gt_bboxes_ignore = None
            gt_labels_ignore = None

        # compute overlap and assign gt on CPU when number of GT is large
        assign_on_cpu = (self.gpu_assign_thr > 0
                         and gt_bboxes.shape[0] > self.gpu_assign_thr)
        if assign_on_cpu:
            device = priors.device
            priors = priors.cpu()
            gt_bboxes = gt_bboxes.cpu()
            gt_labels = gt_labels.cpu()
            if gt_bboxes_ignore is not None:
                gt_bboxes_ignore = gt_bboxes_ignore.cpu()
                gt_labels_ignore = gt_labels_ignore.cpu()

        if gt_bboxes_ignore is not None:
            all_bboxes = torch.cat([gt_bboxes, gt_bboxes_ignore], dim=0)
            all_labels = torch.cat([gt_labels, gt_labels_ignore], dim=0)
        else:
            all_bboxes = gt_bboxes
            all_labels = gt_labels
        # The boxes themselves are appended to the priors so that they are
        # assigned too.
        all_priors = torch.cat([priors, all_bboxes], dim=0)

        # 'iou' overlaps are kept only for normal boxes and 'iof' overlaps
        # only for ignored (label -1) boxes; the (num_boxes,) column mask
        # broadcasts over every prior row.
        is_ignore_col = all_labels.eq(-1)
        overlaps_normal = self.iou_calculator(
            all_priors, all_bboxes, mode='iou') * ~is_ignore_col
        overlaps_ignore = self.iou_calculator(
            all_priors, all_bboxes, mode='iof') * is_ignore_col

        overlaps_normal, overlaps_normal_indices = overlaps_normal.sort(
            descending=True, dim=1)
        overlaps_ignore, overlaps_ignore_indices = overlaps_ignore.sort(
            descending=True, dim=1)

        # select the num_instance best matches of each kind per prior
        max_overlaps_normal, gt_assignment_normal = self._top_k_matches(
            overlaps_normal, overlaps_normal_indices)
        max_overlaps_ignore, gt_assignment_ignore = self._top_k_matches(
            overlaps_ignore, overlaps_ignore_indices)

        # Use the ignored-box match when the normal match is not positive
        # and the ignored box overlaps more.
        ignore_assign_mask = (max_overlaps_normal < self.pos_iou_thr) & (
            max_overlaps_ignore > max_overlaps_normal)
        overlaps = torch.where(ignore_assign_mask, max_overlaps_ignore,
                               max_overlaps_normal)
        gt_assignment = torch.where(ignore_assign_mask, gt_assignment_ignore,
                                    gt_assignment_normal)

        # Advanced indexing returns a copy, so the in-place writes below do
        # not touch all_labels.
        assigned_labels = all_labels[gt_assignment]
        fg_mask = (overlaps >= self.pos_iou_thr) & (assigned_labels != -1)
        bg_mask = (overlaps < self.neg_iou_thr) & (overlaps >= 0)
        assigned_labels[fg_mask] = 1
        assigned_labels[bg_mask] = 0

        overlaps = overlaps.reshape(-1, self.num_instance)
        gt_assignment = gt_assignment.reshape(-1, self.num_instance)
        assigned_labels = assigned_labels.reshape(-1, self.num_instance)

        assign_result = AssignResult(
            num_gts=all_bboxes.size(0),
            gt_inds=gt_assignment,
            max_overlaps=overlaps,
            labels=assigned_labels)

        if assign_on_cpu:
            assign_result.gt_inds = assign_result.gt_inds.to(device)
            assign_result.max_overlaps = assign_result.max_overlaps.to(device)
            if assign_result.labels is not None:
                assign_result.labels = assign_result.labels.to(device)
        return assign_result
|