123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245 |
- # Copyright (c) OpenMMLab. All rights reserved.
- from typing import Optional, Union
- import torch
- from mmengine.structures import InstanceData
- from torch import Tensor
- from mmdet.registry import TASK_UTILS
- from .assign_result import AssignResult
- from .base_assigner import BaseAssigner
@TASK_UTILS.register_module()
class MaxIoUAssigner(BaseAssigner):
    """Assign a corresponding gt bbox or background to each bbox.

    Each bbox (prior) receives one of the following values in
    ``AssignResult.gt_inds``:

    - -1: not assigned (neither positive nor negative, e.g. its max IoU
      falls between the negative and positive thresholds, or it was
      ignored)
    - 0: negative sample (background)
    - positive integer: positive sample, 1-based index of the assigned gt

    Args:
        pos_iou_thr (float): IoU threshold for positive bboxes.
        neg_iou_thr (float or tuple): IoU threshold for negative bboxes.
            A single number ``t`` marks bboxes whose max IoU is in
            ``[0, t)`` as negative; a pair ``(lo, hi)`` (tuple or list)
            marks bboxes whose max IoU is in ``[lo, hi)`` as negative.
        min_pos_iou (float): Minimum iou for a bbox to be considered as a
            positive bbox. Positive samples can have smaller IoU than
            pos_iou_thr due to the 4th step (assign max IoU sample to each gt).
            `min_pos_iou` is set to avoid assigning bboxes that have extremely
            small iou with GT as positive samples. It brings about 0.3 mAP
            improvements in 1x schedule but does not affect the performance of
            3x schedule. More comparisons can be found in
            `PR #7464 <https://github.com/open-mmlab/mmdetection/pull/7464>`_.
        gt_max_assign_all (bool): Whether to assign all bboxes with the same
            highest overlap with some gt to that gt.
        ignore_iof_thr (float): IoF threshold for ignoring bboxes (if
            `gt_bboxes_ignore` is specified). Non-positive values mean not
            ignoring any bboxes.
        ignore_wrt_candidates (bool): Whether to compute the iof between
            `bboxes` and `gt_bboxes_ignore`, or the contrary.
        match_low_quality (bool): Whether to allow low quality matches. This is
            usually allowed for RPN and single stage detectors, but not allowed
            in the second stage. Details are demonstrated in Step 4.
        gpu_assign_thr (int): The upper bound of the number of GT for GPU
            assign. When the number of gt is above this threshold, will assign
            on CPU device. Non-positive values mean never assign on CPU.
        iou_calculator (dict, optional): Config of overlaps calculator.
            Defaults to ``dict(type='BboxOverlaps2D')``.
    """

    def __init__(self,
                 pos_iou_thr: float,
                 neg_iou_thr: Union[float, tuple],
                 min_pos_iou: float = .0,
                 gt_max_assign_all: bool = True,
                 ignore_iof_thr: float = -1,
                 ignore_wrt_candidates: bool = True,
                 match_low_quality: bool = True,
                 gpu_assign_thr: float = -1,
                 iou_calculator: Optional[dict] = None):
        self.pos_iou_thr = pos_iou_thr
        self.neg_iou_thr = neg_iou_thr
        self.min_pos_iou = min_pos_iou
        self.gt_max_assign_all = gt_max_assign_all
        self.ignore_iof_thr = ignore_iof_thr
        self.ignore_wrt_candidates = ignore_wrt_candidates
        self.gpu_assign_thr = gpu_assign_thr
        self.match_low_quality = match_low_quality
        # Build a fresh default config per instance instead of sharing a
        # mutable default argument across all instances.
        if iou_calculator is None:
            iou_calculator = dict(type='BboxOverlaps2D')
        self.iou_calculator = TASK_UTILS.build(iou_calculator)

    def assign(self,
               pred_instances: InstanceData,
               gt_instances: InstanceData,
               gt_instances_ignore: Optional[InstanceData] = None,
               **kwargs) -> AssignResult:
        """Assign gt to bboxes.

        This method assigns a gt bbox to every bbox (proposal/anchor). Each
        bbox will be assigned -1 (not assigned), 0 (negative / background)
        or a positive number, which is the 1-based index of the assigned gt.
        The assignment is done in the following steps; the order matters.

        1. assign every bbox to -1 (not assigned)
        2. assign bboxes whose max iou with all gts is in the negative
           range (see ``neg_iou_thr``) to 0
        3. for each bbox, if the iou with its nearest gt >= pos_iou_thr,
           assign it to that gt
        4. for each gt bbox, assign its nearest proposals (may be more than
           one) to itself

        Args:
            pred_instances (:obj:`InstanceData`): Instances of model
                predictions. It includes ``priors``, and the priors can
                be anchors or points, or the bboxes predicted by the
                previous stage, has shape (n, 4). The bboxes predicted by
                the current model or stage will be named ``bboxes``,
                ``labels``, and ``scores``, the same as the ``InstanceData``
                in other places.
            gt_instances (:obj:`InstanceData`): Ground truth of instance
                annotations. It usually includes ``bboxes``, with shape (k, 4),
                and ``labels``, with shape (k, ).
            gt_instances_ignore (:obj:`InstanceData`, optional): Instances
                to be ignored during training. It includes ``bboxes``
                attribute data that is ignored during training and testing.
                Defaults to None.

        Returns:
            :obj:`AssignResult`: The assign result.

        Example:
            >>> from mmengine.structures import InstanceData
            >>> self = MaxIoUAssigner(0.5, 0.5)
            >>> pred_instances = InstanceData()
            >>> pred_instances.priors = torch.Tensor([[0, 0, 10, 10],
            ...                                      [10, 10, 20, 20]])
            >>> gt_instances = InstanceData()
            >>> gt_instances.bboxes = torch.Tensor([[0, 0, 10, 9]])
            >>> gt_instances.labels = torch.LongTensor([0])
            >>> assign_result = self.assign(pred_instances, gt_instances)
            >>> expected_gt_inds = torch.LongTensor([1, 0])
            >>> assert torch.all(assign_result.gt_inds == expected_gt_inds)
        """
        gt_bboxes = gt_instances.bboxes
        priors = pred_instances.priors
        gt_labels = gt_instances.labels
        gt_bboxes_ignore = (gt_instances_ignore.bboxes
                            if gt_instances_ignore is not None else None)

        # Compute overlaps and assign on CPU when the number of gts exceeds
        # gpu_assign_thr (presumably to limit GPU memory used by the
        # (num_gts, num_priors) overlap matrix).
        assign_on_cpu = (self.gpu_assign_thr > 0
                         and gt_bboxes.shape[0] > self.gpu_assign_thr)
        if assign_on_cpu:
            device = priors.device
            priors = priors.cpu()
            gt_bboxes = gt_bboxes.cpu()
            gt_labels = gt_labels.cpu()
            if gt_bboxes_ignore is not None:
                gt_bboxes_ignore = gt_bboxes_ignore.cpu()

        overlaps = self.iou_calculator(gt_bboxes, priors)

        if (self.ignore_iof_thr > 0 and gt_bboxes_ignore is not None
                and gt_bboxes_ignore.numel() > 0 and priors.numel() > 0):
            if self.ignore_wrt_candidates:
                ignore_overlaps = self.iou_calculator(
                    priors, gt_bboxes_ignore, mode='iof')
                ignore_max_overlaps, _ = ignore_overlaps.max(dim=1)
            else:
                ignore_overlaps = self.iou_calculator(
                    gt_bboxes_ignore, priors, mode='iof')
                ignore_max_overlaps, _ = ignore_overlaps.max(dim=0)
            # Mark ignored priors with overlap -1 against every gt so that
            # they fall below the negative range and stay unassigned.
            overlaps[:, ignore_max_overlaps > self.ignore_iof_thr] = -1

        assign_result = self.assign_wrt_overlaps(overlaps, gt_labels)
        if assign_on_cpu:
            assign_result.gt_inds = assign_result.gt_inds.to(device)
            assign_result.max_overlaps = assign_result.max_overlaps.to(device)
            if assign_result.labels is not None:
                assign_result.labels = assign_result.labels.to(device)
        return assign_result

    def assign_wrt_overlaps(self, overlaps: Tensor,
                            gt_labels: Tensor) -> AssignResult:
        """Assign w.r.t. the overlaps of priors with gts.

        Args:
            overlaps (Tensor): Overlaps between k gt_bboxes and n bboxes,
                shape(k, n).
            gt_labels (Tensor): Labels of k gt_bboxes, shape (k, ).

        Returns:
            :obj:`AssignResult`: The assign result. ``gt_inds`` holds -1
            (not assigned), 0 (negative) or the 1-based index of the
            assigned gt; ``labels`` holds the assigned gt label or -1.

        Raises:
            ValueError: If ``neg_iou_thr`` is a tuple/list whose length is
                not 2.
        """
        num_gts, num_bboxes = overlaps.size(0), overlaps.size(1)

        # 1. assign -1 by default
        assigned_gt_inds = overlaps.new_full((num_bboxes, ),
                                             -1,
                                             dtype=torch.long)

        if num_gts == 0 or num_bboxes == 0:
            # No ground truth or boxes, return empty assignment
            max_overlaps = overlaps.new_zeros((num_bboxes, ))
            assigned_labels = overlaps.new_full((num_bboxes, ),
                                                -1,
                                                dtype=torch.long)
            if num_gts == 0:
                # No truth, assign everything to background
                assigned_gt_inds[:] = 0
            return AssignResult(
                num_gts=num_gts,
                gt_inds=assigned_gt_inds,
                max_overlaps=max_overlaps,
                labels=assigned_labels)

        # for each anchor, which gt best overlaps with it
        # for each anchor, the max iou of all gts
        max_overlaps, argmax_overlaps = overlaps.max(dim=0)
        # for each gt, which anchor best overlaps with it
        # for each gt, the max iou of all proposals
        gt_max_overlaps, gt_argmax_overlaps = overlaps.max(dim=1)

        # 2. assign negative: max IoU within [neg_lo, neg_hi) -> 0
        # A pair (tuple or list, e.g. from a dumped config) gives an explicit
        # range; any other value is treated as a scalar upper bound with
        # lower bound 0, so ints such as 0 work as well as floats. With the
        # scalar form, priors with overlap -1 (ignored in ``assign``) stay
        # unassigned.
        if isinstance(self.neg_iou_thr, (tuple, list)):
            if len(self.neg_iou_thr) != 2:
                raise ValueError(
                    'neg_iou_thr must be a number or a (lo, hi) pair, '
                    f'got {self.neg_iou_thr!r}')
            neg_lo, neg_hi = self.neg_iou_thr
        else:
            neg_lo, neg_hi = 0, self.neg_iou_thr
        assigned_gt_inds[(max_overlaps >= neg_lo)
                         & (max_overlaps < neg_hi)] = 0

        # 3. assign positive: above positive IoU threshold
        pos_inds = max_overlaps >= self.pos_iou_thr
        assigned_gt_inds[pos_inds] = argmax_overlaps[pos_inds] + 1

        if self.match_low_quality:
            # Low-quality matching will overwrite the assigned_gt_inds assigned
            # in Step 3. Thus, the assigned gt might not be the best one for
            # prediction.
            # For example, if bbox A has 0.9 and 0.8 iou with GT bbox 1 & 2,
            # GT bbox 1 will be assigned as the best target for bbox A in
            # step 3. However, if GT bbox 2's gt_argmax_overlaps = A, bbox A's
            # assigned_gt_inds will be overwritten to be GT bbox 2.
            # This might be the reason that it is not used in ROI Heads.
            # Gts are visited in order, so a later gt overwrites an earlier
            # one for the same bbox.
            # NOTE(review): if a gt has zero IoU with every prior and
            # min_pos_iou <= 0, gt_max_overlaps[i] == 0 passes the check and,
            # with gt_max_assign_all, every prior whose IoU with that gt is 0
            # is assigned to it. Kept as-is to preserve existing behavior —
            # confirm whether this is intended.
            for i in range(num_gts):
                if gt_max_overlaps[i] >= self.min_pos_iou:
                    if self.gt_max_assign_all:
                        max_iou_inds = overlaps[i, :] == gt_max_overlaps[i]
                        assigned_gt_inds[max_iou_inds] = i + 1
                    else:
                        assigned_gt_inds[gt_argmax_overlaps[i]] = i + 1

        # Gather the label of the assigned gt for every positive bbox. A
        # boolean mask avoids the 0-d tensor that ``nonzero().squeeze()``
        # yields for a single positive, and needs no host-side emptiness
        # check. Labels are cast to the output dtype so that non-long label
        # tensors do not fail in the masked assignment.
        assigned_labels = assigned_gt_inds.new_full((num_bboxes, ), -1)
        pos_mask = assigned_gt_inds > 0
        assigned_labels[pos_mask] = gt_labels[
            assigned_gt_inds[pos_mask] - 1].to(assigned_labels.dtype)

        return AssignResult(
            num_gts=num_gts,
            gt_inds=assigned_gt_inds,
            max_overlaps=max_overlaps,
            labels=assigned_labels)
|