123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239 |
- # Copyright (c) OpenMMLab. All rights reserved.
- from typing import List, Optional, Tuple
- import torch
- from mmengine.structures import InstanceData
- from torch import Tensor
- from mmdet.registry import TASK_UTILS
- from ..prior_generators import anchor_inside_flags
- from .assign_result import AssignResult
- from .base_assigner import BaseAssigner
def calc_region(
        bbox: Tensor,
        ratio: float,
        stride: int,
        featmap_size: Optional[Tuple[int, int]] = None) -> Tuple[Tensor]:
    """Compute a shrunken region of a box in feature-map coordinates.

    The box is first divided by ``stride``. Every edge is then moved
    towards the opposite edge by ``ratio`` times the box extent along that
    axis, and the resulting coordinates are rounded.

    Args:
        bbox (Tensor): Box whose first four entries are indexed as
            ``x1, y1, x2, y2``.
        ratio (float): Fraction of the box extent by which each edge is
            moved inwards.
        stride (int): Divisor that projects the box onto the feature map.
        featmap_size (tuple[int, int], optional): When given, x coordinates
            are clamped to ``[0, featmap_size[1]]`` and y coordinates to
            ``[0, featmap_size[0]]``. Defaults to None (no clamping).

    Returns:
        tuple[Tensor]: ``(x1, y1, x2, y2)`` of the region.
    """
    # Project the box onto the feature map.
    scaled = bbox / stride
    keep, shift = 1 - ratio, ratio
    corners = [
        torch.round(keep * scaled[0] + shift * scaled[2]),  # x1
        torch.round(keep * scaled[1] + shift * scaled[3]),  # y1
        torch.round(shift * scaled[0] + keep * scaled[2]),  # x2
        torch.round(shift * scaled[1] + keep * scaled[3]),  # y2
    ]
    if featmap_size is not None:
        # Upper bounds in (x1, y1, x2, y2) order: index 1 limits x,
        # index 0 limits y.
        upper = (featmap_size[1], featmap_size[0],
                 featmap_size[1], featmap_size[0])
        corners = [
            torch.clamp(value, min=0, max=limit)
            for value, limit in zip(corners, upper)
        ]
    return tuple(corners)
def anchor_ctr_inside_region_flags(anchors: Tensor, stride: int,
                                   region: Tuple[Tensor]) -> Tensor:
    """Flag the anchors whose centers fall inside a region.

    Anchors are divided by ``stride`` before their centers are computed, so
    ``region`` is compared in the same scaled coordinates. The region
    boundaries are inclusive.

    Args:
        anchors (Tensor): 2-D tensor whose first four columns are indexed
            as ``x1, y1, x2, y2``.
        stride (int): Divisor applied to the anchors.
        region (tuple[Tensor]): ``(x1, y1, x2, y2)`` bounds of the region.

    Returns:
        Tensor: Boolean mask with one entry per anchor.
    """
    left, top, right, bottom = region
    a_x1, a_y1, a_x2, a_y2 = (anchors / stride)[:, :4].unbind(dim=1)
    ctr_x = (a_x1 + a_x2) * 0.5
    ctr_y = (a_y1 + a_y2) * 0.5
    inside_x = (ctr_x >= left) & (ctr_x <= right)
    inside_y = (ctr_y >= top) & (ctr_y <= bottom)
    return inside_x & inside_y
@TASK_UTILS.register_module()
class RegionAssigner(BaseAssigner):
    """Assign a ground-truth box or background to every anchor by region.

    Every anchor receives ``-1``, ``0`` or a positive integer:

    - -1: don't care
    - 0: negative sample, no assigned gt
    - positive integer: positive sample, index (1-based) of assigned gt

    Args:
        center_ratio (float): Fraction of the gt box extent, centred in the
            box, whose anchors become positive samples. Defaults to 0.2.
        ignore_ratio (float): Fraction of the gt box extent, centred in the
            box, whose anchors are ignored. Defaults to 0.5.
    """

    def __init__(self,
                 center_ratio: float = 0.2,
                 ignore_ratio: float = 0.5) -> None:
        self.center_ratio = center_ratio
        self.ignore_ratio = ignore_ratio

    def assign(self,
               pred_instances: InstanceData,
               gt_instances: InstanceData,
               img_meta: dict,
               featmap_sizes: List[Tuple[int, int]],
               num_level_anchors: List[int],
               anchor_scale: int,
               anchor_strides: List[int],
               gt_instances_ignore: Optional[InstanceData] = None,
               allowed_border: int = 0) -> AssignResult:
        """Assign gt boxes to multi-level anchors.

        Each anchor ends up with -1 (don't care), 0 (negative) or the
        1-based index of its assigned gt. The steps below run in this
        order, and later writes override earlier ones:

        1. Every anchor starts as 0 (negative).
        2. (Per gt) Anchors of the gt's target level whose centers lie in
           the ignore region are set to -1.
        3. (Per gt) Anchors of the target level whose centers lie in the
           center region are set to the gt index.
        4. (Per gt) Anchors of the levels directly below and above the
           target level whose centers lie in the ignore region are
           recorded, and set to -1 once all gts have been processed.
        5. Anchors reported as not inside the image are set to -1.

        Args:
            pred_instances (:obj:`InstanceData`): Must provide ``priors``
                (the flattened anchors of all levels) and ``valid_flags``
                (one flag per anchor).
            gt_instances (:obj:`InstanceData`): Must provide ``bboxes``
                and ``labels``.
            img_meta (dict): Meta info of the image; ``'img_shape'`` is
                read.
            featmap_sizes (list[tuple[int, int]]): Feature map size of
                each level.
            num_level_anchors (list[int]): Number of anchors in each level;
                used to split the flattened anchors per level.
            anchor_scale (int): Scale of the anchor.
            anchor_strides (list[int]): Stride of the anchors of each
                level.
            gt_instances_ignore (:obj:`InstanceData`, optional): Not
                supported; must be None. Defaults to None.
            allowed_border (int, optional): Border passed on to
                ``anchor_inside_flags``. Defaults to 0.

        Returns:
            :obj:`AssignResult`: The assign result. ``max_overlaps`` is
            None unless there are no gts or no anchors.

        Raises:
            NotImplementedError: If ``gt_instances_ignore`` is not None.
        """
        if gt_instances_ignore is not None:
            raise NotImplementedError

        num_gts = len(gt_instances)
        num_bboxes = len(pred_instances)

        gt_bboxes = gt_instances.bboxes
        gt_labels = gt_instances.labels
        flat_anchors = pred_instances.priors
        flat_valid_flags = pred_instances.valid_flags
        mlvl_anchors = torch.split(flat_anchors, num_level_anchors)

        if num_gts == 0 or num_bboxes == 0:
            # Nothing to match: all anchors negative, labels set to -1.
            return AssignResult(
                num_gts=num_gts,
                gt_inds=gt_bboxes.new_zeros((num_bboxes, ),
                                            dtype=torch.long),
                max_overlaps=gt_bboxes.new_zeros((num_bboxes, )),
                labels=gt_bboxes.new_full((num_bboxes, ),
                                          -1,
                                          dtype=torch.long))

        num_lvls = len(mlvl_anchors)
        # Per-edge inset ratios that leave the central center_ratio /
        # ignore_ratio fraction of the box.
        r1 = (1 - self.center_ratio) / 2
        r2 = (1 - self.ignore_ratio) / 2

        # Pick one target level per gt from the square root of its area
        # relative to the anchor size of the first level.
        gt_w = gt_bboxes[:, 2] - gt_bboxes[:, 0]
        gt_h = gt_bboxes[:, 3] - gt_bboxes[:, 1]
        scale = torch.sqrt(gt_w * gt_h)
        min_anchor_size = scale.new_full(
            (1, ), float(anchor_scale * anchor_strides[0]))
        target_lvls = torch.floor(
            torch.log2(scale) - torch.log2(min_anchor_size) + 0.5)
        target_lvls = target_lvls.clamp(min=0, max=num_lvls - 1).long()

        # 1. assign 0 (negative) by default
        mlvl_assigned_gt_inds = [
            gt_bboxes.new_full((level_count, ), 0, dtype=torch.long)
            for level_count in num_level_anchors
        ]
        mlvl_ignore_flags = [
            torch.zeros_like(level_inds)
            for level_inds in mlvl_assigned_gt_inds
        ]

        for gt_id, lvl in enumerate(target_lvls.tolist()):
            gt_bbox = gt_bboxes[gt_id, :4]
            lvl_stride = anchor_strides[lvl]
            lvl_anchors = mlvl_anchors[lvl]
            lvl_size = featmap_sizes[lvl]
            lvl_inds = mlvl_assigned_gt_inds[lvl]

            # Compute regions
            ignore_region = calc_region(gt_bbox, r2, lvl_stride, lvl_size)
            ctr_region = calc_region(gt_bbox, r1, lvl_stride, lvl_size)

            # 2. Assign -1 to ignore flags
            in_ignore = anchor_ctr_inside_region_flags(
                lvl_anchors, lvl_stride, ignore_region)
            lvl_inds[in_ignore] = -1

            # 3. Assign gt_bboxes to pos flags (written after step 2 so the
            # center region wins inside the ignore region)
            in_center = anchor_ctr_inside_region_flags(
                lvl_anchors, lvl_stride, ctr_region)
            lvl_inds[in_center] = gt_id + 1

            # 4. Record ignore flags on the adjacent lower / upper level
            for adj_lvl in (lvl - 1, lvl + 1):
                if adj_lvl < 0 or adj_lvl > num_lvls - 1:
                    continue
                adj_stride = anchor_strides[adj_lvl]
                adj_region = calc_region(gt_bbox, r2, adj_stride,
                                         featmap_sizes[adj_lvl])
                adj_flags = anchor_ctr_inside_region_flags(
                    mlvl_anchors[adj_lvl], adj_stride, adj_region)
                mlvl_ignore_flags[adj_lvl][adj_flags] = 1

        # 4. (cont.) Assign -1 to ignore adjacent lvl
        for level_inds, level_ignore in zip(mlvl_assigned_gt_inds,
                                            mlvl_ignore_flags):
            level_inds[level_ignore == 1] = -1

        # 5. Assign -1 to anchor outside of image
        flat_assigned_gt_inds = torch.cat(mlvl_assigned_gt_inds)
        assert (flat_assigned_gt_inds.shape[0] == flat_anchors.shape[0] ==
                flat_valid_flags.shape[0])
        inside_flags = anchor_inside_flags(flat_anchors, flat_valid_flags,
                                           img_meta['img_shape'],
                                           allowed_border)
        flat_assigned_gt_inds[~inside_flags] = -1

        # Labels stay 0 except for positives, which take the label of
        # their assigned gt (indices are 1-based, hence the -1).
        assigned_labels = torch.zeros_like(flat_assigned_gt_inds)
        is_pos = flat_assigned_gt_inds > 0
        assigned_labels[is_pos] = gt_labels[flat_assigned_gt_inds[is_pos] -
                                            1]

        return AssignResult(
            num_gts=num_gts,
            gt_inds=flat_assigned_gt_inds,
            max_overlaps=None,
            labels=assigned_labels)
|