123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173 |
- from typing import Optional
- import torch
- from mmengine.structures import InstanceData
- from mmdet.registry import TASK_UTILS
- from mmdet.structures.bbox import bbox_xyxy_to_cxcywh
- from mmdet.utils import ConfigType
- from .assign_result import AssignResult
- from .base_assigner import BaseAssigner
@TASK_UTILS.register_module()
class UniformAssigner(BaseAssigner):
    """Uniform Matching between the priors and gt boxes, which can achieve
    balance in positive priors, and gt_bboxes_ignore was not considered for
    now.

    For every gt box, the ``match_times`` decoded predictions and the
    ``match_times`` priors that are closest to it (L1 distance in
    (cx, cy, w, h) space) become its positive candidates. Every gt box
    therefore receives the same number of candidates.

    Args:
        pos_ignore_thr (float): The threshold to ignore positive priors.
            A positive candidate whose prior has an IoU with its gt box
            below this value is marked as ignored (-1).
        neg_ignore_thr (float): The threshold to ignore negative priors.
            A prediction whose maximum IoU with any gt box is above this
            value is marked as ignored (-1) instead of background (0).
        match_times (int): Number of positive priors for each gt box.
            Defaults to 4.
        iou_calculator (:obj:`ConfigDict` or dict): Config dict for iou
            calculator. Defaults to ``dict(type='BboxOverlaps2D')``
    """

    def __init__(self,
                 pos_ignore_thr: float,
                 neg_ignore_thr: float,
                 match_times: int = 4,
                 iou_calculator: ConfigType = dict(type='BboxOverlaps2D')):
        self.match_times = match_times
        self.pos_ignore_thr = pos_ignore_thr
        self.neg_ignore_thr = neg_ignore_thr
        self.iou_calculator = TASK_UTILS.build(iou_calculator)

    def assign(
        self,
        pred_instances: InstanceData,
        gt_instances: InstanceData,
        gt_instances_ignore: Optional[InstanceData] = None
    ) -> AssignResult:
        """Assign gt to priors.

        The assignment is done in following steps

        1. assign 0 (background) to every prediction by default
        2. compute the L1 cost between boxes. Note that we use priors and
           predict boxes both
        3. compute the ignore indexes use gt_bboxes and predict boxes
        4. compute the ignore indexes of positive sample use priors and
           predict boxes

        Args:
            pred_instances (:obj:`InstanceData`): Instances of model
                predictions. It must include ``priors`` and
                ``decoder_priors`` (the boxes decoded from the model's
                predictions). Both are converted with
                ``bbox_xyxy_to_cxcywh``, so both are expected in
                (x1, y1, x2, y2) format with shape (n, 4).
            gt_instances (:obj:`InstanceData`): Ground truth of instance
                annotations. It usually includes ``bboxes`` and ``labels``
                attributes.
            gt_instances_ignore (:obj:`InstanceData`, optional): Instances
                to be ignored during training. Currently not used by this
                assigner. Defaults to None.

        Returns:
            :obj:`AssignResult`: The assign result. It also carries these
            extra properties, which are empty when there are no gts or no
            predictions:

            - ``pos_idx``: bool mask over the positive candidates. It is
              False where a candidate was ignored by ``pos_ignore_thr``.
            - ``pos_predicted_boxes``: the candidate predictions.
            - ``target_boxes``: the gt box each candidate was matched to.
        """
        gt_bboxes = gt_instances.bboxes
        gt_labels = gt_instances.labels
        priors = pred_instances.priors
        bbox_pred = pred_instances.decoder_priors
        num_gts, num_bboxes = gt_bboxes.size(0), bbox_pred.size(0)

        # 1. Every prediction starts as background (0); labels default
        # to -1 (no class).
        assigned_gt_inds = bbox_pred.new_full((num_bboxes, ),
                                              0,
                                              dtype=torch.long)
        assigned_labels = bbox_pred.new_full((num_bboxes, ),
                                             -1,
                                             dtype=torch.long)
        if num_gts == 0 or num_bboxes == 0:
            # Nothing to match. Keep everything as background and attach
            # zero-length extra properties so callers see consistent
            # shapes.
            assign_result = AssignResult(
                num_gts, assigned_gt_inds, None, labels=assigned_labels)
            assign_result.set_extra_property(
                'pos_idx', bbox_pred.new_empty(0, dtype=torch.bool))
            assign_result.set_extra_property('pos_predicted_boxes',
                                             bbox_pred.new_empty((0, 4)))
            assign_result.set_extra_property('target_boxes',
                                             bbox_pred.new_empty((0, 4)))
            return assign_result

        # topk cannot select more elements than exist along dim 0. Clamp
        # the count so inputs with fewer boxes than ``match_times`` do not
        # raise.
        num_candidates = min(self.match_times, num_bboxes, priors.size(0))

        # 2. L1 cost in (cx, cy, w, h) space from the decoded predictions
        # and from the priors to every gt box. Each matrix has shape
        # (num_boxes, num_gts).
        cost_bbox = torch.cdist(
            bbox_xyxy_to_cxcywh(bbox_pred),
            bbox_xyxy_to_cxcywh(gt_bboxes),
            p=1)
        cost_bbox_priors = torch.cdist(
            bbox_xyxy_to_cxcywh(priors), bbox_xyxy_to_cxcywh(gt_bboxes), p=1)

        # NOTE(review): top-k is computed on CPU. This is presumably to
        # keep results identical across devices; confirm before moving it
        # back to the GPU.
        # Each index tensor has shape (num_candidates, num_gts) and holds
        # the nearest predictions / priors for every gt.
        pred_topk_inds = torch.topk(
            cost_bbox.cpu(), k=num_candidates, dim=0, largest=False)[1]
        prior_topk_inds = torch.topk(
            cost_bbox_priors.cpu(), k=num_candidates, dim=0,
            largest=False)[1]
        # After concatenation the shape is (num_candidates, 2 * num_gts),
        # flattened row-major. Each row holds the r-th nearest prediction
        # for every gt, followed by the r-th nearest prior for every gt.
        # ``pos_gt_index`` below follows exactly this ordering.
        indexes = torch.cat((pred_topk_inds, prior_topk_inds),
                            dim=1).reshape(-1).to(bbox_pred.device)

        # IoU matrices have rows = boxes and columns = gts.
        pred_overlaps = self.iou_calculator(bbox_pred, gt_bboxes)
        anchor_overlaps = self.iou_calculator(priors, gt_bboxes)
        pred_max_overlaps, _ = pred_overlaps.max(dim=1)
        # NOTE(review): this reduces over the priors (dim=0), giving one
        # value per gt rather than one per prediction. It is kept as-is
        # because it is forwarded unchanged to AssignResult; confirm the
        # intended shape.
        anchor_max_overlaps, _ = anchor_overlaps.max(dim=0)

        # 3. Predictions that already overlap some gt strongly are marked
        # as ignored (-1) instead of background.
        ignore_idx = pred_max_overlaps > self.neg_ignore_thr
        assigned_gt_inds[ignore_idx] = -1

        # 4. Give every entry of ``indexes`` its gt index (the pattern
        # 0..num_gts-1 repeated 2 * num_candidates times). Candidates whose
        # prior IoU with that gt is below pos_ignore_thr are ignored.
        pos_gt_index = torch.arange(
            0, num_gts,
            device=bbox_pred.device).repeat(num_candidates * 2)
        pos_ious = anchor_overlaps[indexes, pos_gt_index]
        pos_ignore_idx = pos_ious < self.pos_ignore_thr
        pos_gt_index_with_ignore = pos_gt_index + 1  # 1-based gt index
        pos_gt_index_with_ignore[pos_ignore_idx] = -1
        # NOTE(review): ``indexes`` can contain duplicates, because one
        # prediction may be a candidate for several gts. PyTorch leaves
        # unspecified which duplicate write wins.
        assigned_gt_inds[indexes] = pos_gt_index_with_ignore

        if gt_labels is not None:
            # Positive predictions (gt index > 0) take their gt's label;
            # all others keep -1.
            pos_mask = assigned_gt_inds > 0
            assigned_labels[pos_mask] = gt_labels[
                assigned_gt_inds[pos_mask] - 1]
        else:
            assigned_labels = None

        assign_result = AssignResult(
            num_gts,
            assigned_gt_inds,
            anchor_max_overlaps,
            labels=assigned_labels)
        assign_result.set_extra_property('pos_idx', ~pos_ignore_idx)
        assign_result.set_extra_property('pos_predicted_boxes',
                                         bbox_pred[indexes])
        assign_result.set_extra_property('target_boxes',
                                         gt_bboxes[pos_gt_index])
        return assign_result
|