# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmengine.config import ConfigDict
from mmengine.structures import InstanceData
from torch import Tensor, nn

from mmdet.models.roi_heads.bbox_heads.bbox_head import BBoxHead
from mmdet.models.task_modules.samplers import SamplingResult
from mmdet.models.utils import empty_instances
from mmdet.registry import MODELS
from mmdet.structures.bbox import bbox_overlaps


@MODELS.register_module()
class MultiInstanceBBoxHead(BBoxHead):
    r"""Bbox head used in CrowdDet.

    .. code-block:: none

                                      /-> cls convs_1 -> cls fcs_1 -> cls_1
                                   |--
                                   |  \-> reg convs_1 -> reg fcs_1 -> reg_1
                                   |
                                   |  /-> cls convs_2 -> cls fcs_2 -> cls_2
        shared convs -> shared fcs |--
                                   |  \-> reg convs_2 -> reg fcs_2 -> reg_2
                                   |
                                   |                     ...
                                   |
                                   |  /-> cls convs_k -> cls fcs_k -> cls_k
                                   |--
                                      \-> reg convs_k -> reg fcs_k -> reg_k

    Args:
        num_instance (int): The number of branches after shared fcs.
            Defaults to 2.
        with_refine (bool): Whether to use refine module. Defaults to False.
        num_shared_convs (int): The number of shared convs. Defaults to 0.
        num_shared_fcs (int): The number of shared fcs. Defaults to 2.
        num_cls_convs (int): The number of cls convs. Defaults to 0.
        num_cls_fcs (int): The number of cls fcs. Defaults to 0.
        num_reg_convs (int): The number of reg convs. Defaults to 0.
        num_reg_fcs (int): The number of reg fcs. Defaults to 0.
        conv_out_channels (int): The number of conv out channels.
            Defaults to 256.
        fc_out_channels (int): The number of fc out channels. Defaults to
            1024.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Defaults to None.
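
    Example:
        A minimal construction and forward sketch (illustrative only): the
        shapes assume the defaults ``in_channels=256`` and
        ``roi_feat_size=7``, a single foreground class as in the CrowdDet
        configs, and that the default mmdet registries are available.

        >>> import torch
        >>> from mmdet.models.roi_heads.bbox_heads import \
        ...     MultiInstanceBBoxHead
        >>> self = MultiInstanceBBoxHead(num_classes=1)
        >>> x = torch.rand(8, 256, 7, 7)  # 8 RoI features
        >>> cls_score, bbox_pred = self.forward(x)
        >>> cls_score.shape  # (num_rois, num_instance * (num_classes + 1))
        torch.Size([8, 4])
        >>> bbox_pred.shape  # (num_rois, num_instance * 4)
        torch.Size([8, 8])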
- """ # noqa: W605

    def __init__(self,
                 num_instance: int = 2,
                 with_refine: bool = False,
                 num_shared_convs: int = 0,
                 num_shared_fcs: int = 2,
                 num_cls_convs: int = 0,
                 num_cls_fcs: int = 0,
                 num_reg_convs: int = 0,
                 num_reg_fcs: int = 0,
                 conv_out_channels: int = 256,
                 fc_out_channels: int = 1024,
                 init_cfg: Optional[Union[dict, ConfigDict]] = None,
                 *args,
                 **kwargs) -> None:
        super().__init__(*args, init_cfg=init_cfg, **kwargs)
        assert (num_shared_convs + num_shared_fcs + num_cls_convs +
                num_cls_fcs + num_reg_convs + num_reg_fcs > 0)
        assert num_instance == 2, 'Currently only 2 instances are supported'
        if num_cls_convs > 0 or num_reg_convs > 0:
            assert num_shared_fcs == 0
        if not self.with_cls:
            assert num_cls_convs == 0 and num_cls_fcs == 0
        if not self.with_reg:
            assert num_reg_convs == 0 and num_reg_fcs == 0
        self.num_instance = num_instance
        self.num_shared_convs = num_shared_convs
        self.num_shared_fcs = num_shared_fcs
        self.num_cls_convs = num_cls_convs
        self.num_cls_fcs = num_cls_fcs
        self.num_reg_convs = num_reg_convs
        self.num_reg_fcs = num_reg_fcs
        self.conv_out_channels = conv_out_channels
        self.fc_out_channels = fc_out_channels
        self.with_refine = with_refine

        # add shared convs and fcs
        self.shared_convs, self.shared_fcs, last_layer_dim = \
            self._add_conv_fc_branch(
                self.num_shared_convs, self.num_shared_fcs, self.in_channels,
                True)
        self.shared_out_channels = last_layer_dim
        self.relu = nn.ReLU(inplace=True)

        if self.with_refine:
            refine_model_cfg = {
                'type': 'Linear',
                'in_features': self.shared_out_channels + 20,
                'out_features': self.shared_out_channels
            }
            self.shared_fcs_ref = MODELS.build(refine_model_cfg)
            self.fc_cls_ref = nn.ModuleList()
            self.fc_reg_ref = nn.ModuleList()

        self.cls_convs = nn.ModuleList()
        self.cls_fcs = nn.ModuleList()
        self.reg_convs = nn.ModuleList()
        self.reg_fcs = nn.ModuleList()
        self.cls_last_dim = list()
        self.reg_last_dim = list()
        self.fc_cls = nn.ModuleList()
        self.fc_reg = nn.ModuleList()

        for k in range(self.num_instance):
            # add cls specific branch
            cls_convs, cls_fcs, cls_last_dim = self._add_conv_fc_branch(
                self.num_cls_convs, self.num_cls_fcs, self.shared_out_channels)
            self.cls_convs.append(cls_convs)
            self.cls_fcs.append(cls_fcs)
            self.cls_last_dim.append(cls_last_dim)

            # add reg specific branch
            reg_convs, reg_fcs, reg_last_dim = self._add_conv_fc_branch(
                self.num_reg_convs, self.num_reg_fcs, self.shared_out_channels)
            self.reg_convs.append(reg_convs)
            self.reg_fcs.append(reg_fcs)
            self.reg_last_dim.append(reg_last_dim)

            if self.num_shared_fcs == 0 and not self.with_avg_pool:
                # `cls_last_dim` and `reg_last_dim` are per-branch lists,
                # so only the k-th entry is scaled by the flattened RoI area
                if self.num_cls_fcs == 0:
                    self.cls_last_dim[k] *= self.roi_feat_area
                if self.num_reg_fcs == 0:
                    self.reg_last_dim[k] *= self.roi_feat_area

            if self.with_cls:
                if self.custom_cls_channels:
                    cls_channels = self.loss_cls.get_cls_channels(
                        self.num_classes)
                else:
                    cls_channels = self.num_classes + 1
                cls_predictor_cfg_ = self.cls_predictor_cfg.copy()
                cls_predictor_cfg_.update(
                    in_features=self.cls_last_dim[k],
                    out_features=cls_channels)
                self.fc_cls.append(MODELS.build(cls_predictor_cfg_))
                if self.with_refine:
                    self.fc_cls_ref.append(MODELS.build(cls_predictor_cfg_))

            if self.with_reg:
                out_dim_reg = (4 if self.reg_class_agnostic else 4 *
                               self.num_classes)
                reg_predictor_cfg_ = self.reg_predictor_cfg.copy()
                reg_predictor_cfg_.update(
                    in_features=self.reg_last_dim[k],
                    out_features=out_dim_reg)
                self.fc_reg.append(MODELS.build(reg_predictor_cfg_))
                if self.with_refine:
                    self.fc_reg_ref.append(MODELS.build(reg_predictor_cfg_))

        if init_cfg is None:
            # when init_cfg is None, it has been set to
            # [[dict(type='Normal', std=0.01, override=dict(name='fc_cls'))],
            #  [dict(type='Normal', std=0.001, override=dict(name='fc_reg'))]]
            # by `super().__init__()`, so we only need to append additional
            # configuration for `shared_fcs`, `cls_fcs` and `reg_fcs`
            self.init_cfg += [
                dict(
                    type='Xavier',
                    distribution='uniform',
                    override=[
                        dict(name='shared_fcs'),
                        dict(name='cls_fcs'),
                        dict(name='reg_fcs')
                    ])
            ]

    def _add_conv_fc_branch(self,
                            num_branch_convs: int,
                            num_branch_fcs: int,
                            in_channels: int,
                            is_shared: bool = False) -> tuple:
- """Add shared or separable branch.
- convs -> avg pool (optional) -> fcs
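
        Example:
            A hedged sketch of the returned dimensions, assuming the default
            ``roi_feat_size=7`` (so the flattened RoI area is 49) and no
            average pooling:

            >>> self = MultiInstanceBBoxHead(num_classes=1)
            >>> convs, fcs, last_dim = self._add_conv_fc_branch(
            ...     0, 2, 256, is_shared=True)
            >>> # 256 * 49 = 12544 -> 1024 -> 1024
            >>> (len(convs), len(fcs), last_dim)
            (0, 2, 1024)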
- """
        last_layer_dim = in_channels
        # add branch specific conv layers
        branch_convs = nn.ModuleList()
        if num_branch_convs > 0:
            for i in range(num_branch_convs):
                conv_in_channels = (
                    last_layer_dim if i == 0 else self.conv_out_channels)
                branch_convs.append(
                    ConvModule(
                        conv_in_channels, self.conv_out_channels, 3,
                        padding=1))
            last_layer_dim = self.conv_out_channels
        # add branch specific fc layers
        branch_fcs = nn.ModuleList()
        if num_branch_fcs > 0:
            # for shared branch, only consider self.with_avg_pool
            # for separated branches, also consider self.num_shared_fcs
            if (is_shared
                    or self.num_shared_fcs == 0) and not self.with_avg_pool:
                last_layer_dim *= self.roi_feat_area
            for i in range(num_branch_fcs):
                fc_in_channels = (
                    last_layer_dim if i == 0 else self.fc_out_channels)
                branch_fcs.append(
                    nn.Linear(fc_in_channels, self.fc_out_channels))
            last_layer_dim = self.fc_out_channels
        return branch_convs, branch_fcs, last_layer_dim

    def forward(self, x: Tuple[Tensor]) -> tuple:
        """Forward features from the upstream network.

        Args:
            x (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.

        Returns:
            tuple: A tuple of classification scores and bbox predictions.

                - cls_score (Tensor): Classification scores of all RoIs,
                  with shape (num_rois, num_instance * (num_classes + 1)).
                - bbox_pred (Tensor): Box energies / deltas of all RoIs,
                  with shape (num_rois, num_instance * 4).
                - cls_score_ref (Tensor): The cls_score after the refine
                  model, only returned when ``with_refine=True``.
                - bbox_pred_ref (Tensor): The bbox_pred after the refine
                  model, only returned when ``with_refine=True``.
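
        Example:
            A hedged shape sketch with ``with_refine=True`` and a single
            foreground class; the refine branch consumes the first-stage
            predictions and returns a second set of scores and deltas:

            >>> import torch
            >>> self = MultiInstanceBBoxHead(with_refine=True, num_classes=1)
            >>> outs = self.forward(torch.rand(8, 256, 7, 7))
            >>> [tuple(o.shape) for o in outs]
            [(8, 4), (8, 8), (8, 4), (8, 8)]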
- """
        # shared part
        if self.num_shared_convs > 0:
            for conv in self.shared_convs:
                x = conv(x)

        if self.num_shared_fcs > 0:
            if self.with_avg_pool:
                x = self.avg_pool(x)

            x = x.flatten(1)
            for fc in self.shared_fcs:
                x = self.relu(fc(x))

        x_cls = x
        x_reg = x
        # separate branches
        cls_score = list()
        bbox_pred = list()
        for k in range(self.num_instance):
            for conv in self.cls_convs[k]:
                x_cls = conv(x_cls)
            if x_cls.dim() > 2:
                if self.with_avg_pool:
                    x_cls = self.avg_pool(x_cls)
                x_cls = x_cls.flatten(1)
            for fc in self.cls_fcs[k]:
                x_cls = self.relu(fc(x_cls))

            for conv in self.reg_convs[k]:
                x_reg = conv(x_reg)
            if x_reg.dim() > 2:
                if self.with_avg_pool:
                    x_reg = self.avg_pool(x_reg)
                x_reg = x_reg.flatten(1)
            for fc in self.reg_fcs[k]:
                x_reg = self.relu(fc(x_reg))

            cls_score.append(self.fc_cls[k](x_cls) if self.with_cls else None)
            bbox_pred.append(self.fc_reg[k](x_reg) if self.with_reg else None)

        if self.with_refine:
            x_ref = x
            cls_score_ref = list()
            bbox_pred_ref = list()
            for k in range(self.num_instance):
                feat_ref = cls_score[k].softmax(dim=-1)
                feat_ref = torch.cat((bbox_pred[k], feat_ref[:, 1][:, None]),
                                     dim=1).repeat(1, 4)
                feat_ref = torch.cat((x_ref, feat_ref), dim=1)
                feat_ref = F.relu_(self.shared_fcs_ref(feat_ref))

                cls_score_ref.append(self.fc_cls_ref[k](feat_ref))
                bbox_pred_ref.append(self.fc_reg_ref[k](feat_ref))

            cls_score = torch.cat(cls_score, dim=1)
            bbox_pred = torch.cat(bbox_pred, dim=1)
            cls_score_ref = torch.cat(cls_score_ref, dim=1)
            bbox_pred_ref = torch.cat(bbox_pred_ref, dim=1)
            return cls_score, bbox_pred, cls_score_ref, bbox_pred_ref

        cls_score = torch.cat(cls_score, dim=1)
        bbox_pred = torch.cat(bbox_pred, dim=1)
        return cls_score, bbox_pred

    def get_targets(self,
                    sampling_results: List[SamplingResult],
                    rcnn_train_cfg: ConfigDict,
                    concat: bool = True) -> tuple:
- """Calculate the ground truth for all samples in a batch according to
- the sampling_results.
- Almost the same as the implementation in bbox_head, we passed
- additional parameters pos_inds_list and neg_inds_list to
- `_get_targets_single` function.
- Args:
- sampling_results (List[obj:SamplingResult]): Assign results of
- all images in a batch after sampling.
- rcnn_train_cfg (obj:ConfigDict): `train_cfg` of RCNN.
- concat (bool): Whether to concatenate the results of all
- the images in a single batch.
- Returns:
- Tuple[Tensor]: Ground truth for proposals in a single image.
- Containing the following list of Tensors:
- - labels (list[Tensor],Tensor): Gt_labels for all proposals in a
- batch, each tensor in list has shape (num_proposals,) when
- `concat=False`, otherwise just a single tensor has shape
- (num_all_proposals,).
- - label_weights (list[Tensor]): Labels_weights for
- all proposals in a batch, each tensor in list has shape
- (num_proposals,) when `concat=False`, otherwise just a single
- tensor has shape (num_all_proposals,).
- - bbox_targets (list[Tensor],Tensor): Regression target for all
- proposals in a batch, each tensor in list has shape
- (num_proposals, 4) when `concat=False`, otherwise just a single
- tensor has shape (num_all_proposals, 4), the last dimension 4
- represents [tl_x, tl_y, br_x, br_y].
- - bbox_weights (list[tensor],Tensor): Regression weights for
- all proposals in a batch, each tensor in list has shape
- (num_proposals, 4) when `concat=False`, otherwise just a
- single tensor has shape (num_all_proposals, 4).
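
        Example:
            A hedged sketch of the per-image bookkeeping for
            ``num_instance=2``: each sampled prior is repeated once per
            instance so it can be encoded against its two matched gt boxes:

            >>> import torch
            >>> priors = torch.rand(16, 4)  # 16 sampled proposals
            >>> priors.repeat(1, 2).reshape(-1, 4).shape
            torch.Size([32, 4])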
- """
        labels = []
        bbox_targets = []
        bbox_weights = []
        label_weights = []
        for i in range(len(sampling_results)):
            sample_bboxes = torch.cat([
                sampling_results[i].pos_gt_bboxes,
                sampling_results[i].neg_gt_bboxes
            ])
            sample_priors = sampling_results[i].priors
            sample_priors = sample_priors.repeat(
                1, self.num_instance).reshape(-1, 4)
            sample_bboxes = sample_bboxes.reshape(-1, 4)

            if not self.reg_decoded_bbox:
                _bbox_targets = self.bbox_coder.encode(sample_priors,
                                                       sample_bboxes)
            else:
                # when using a decoded-bbox loss (e.g. IoU losses), the
                # regression targets are the gt boxes themselves
                _bbox_targets = sample_bboxes
            _bbox_targets = _bbox_targets.reshape(-1, self.num_instance * 4)
            # create weights on the same device as the targets
            _bbox_weights = torch.ones_like(_bbox_targets)
            _labels = torch.cat([
                sampling_results[i].pos_gt_labels,
                sampling_results[i].neg_gt_labels
            ])
            _labels_weights = torch.ones_like(_labels, dtype=torch.float)

            bbox_targets.append(_bbox_targets)
            bbox_weights.append(_bbox_weights)
            labels.append(_labels)
            label_weights.append(_labels_weights)

        if concat:
            labels = torch.cat(labels, 0)
            label_weights = torch.cat(label_weights, 0)
            bbox_targets = torch.cat(bbox_targets, 0)
            bbox_weights = torch.cat(bbox_weights, 0)
        return labels, label_weights, bbox_targets, bbox_weights

    def loss(self, cls_score: Tensor, bbox_pred: Tensor, rois: Tensor,
             labels: Tensor, label_weights: Tensor, bbox_targets: Tensor,
             bbox_weights: Tensor, **kwargs) -> dict:
- """Calculate the loss based on the network predictions and targets.
- Args:
- cls_score (Tensor): Classification prediction results of all class,
- has shape (batch_size * num_proposals_single_image,
- (num_classes + 1) * k), k represents the number of prediction
- boxes generated by each proposal box.
- bbox_pred (Tensor): Regression prediction results, has shape
- (batch_size * num_proposals_single_image, 4 * k), the last
- dimension 4 represents [tl_x, tl_y, br_x, br_y].
- rois (Tensor): RoIs with the shape
- (batch_size * num_proposals_single_image, 5) where the first
- column indicates batch id of each RoI.
- labels (Tensor): Gt_labels for all proposals in a batch, has
- shape (batch_size * num_proposals_single_image, k).
- label_weights (Tensor): Labels_weights for all proposals in a
- batch, has shape (batch_size * num_proposals_single_image, k).
- bbox_targets (Tensor): Regression target for all proposals in a
- batch, has shape (batch_size * num_proposals_single_image,
- 4 * k), the last dimension 4 represents [tl_x, tl_y, br_x,
- br_y].
- bbox_weights (Tensor): Regression weights for all proposals in a
- batch, has shape (batch_size * num_proposals_single_image,
- 4 * k).
- Returns:
- dict: A dictionary of loss.
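
        Example:
            A hedged sketch of the EMD selection below: the two possible
            prediction-to-target assignments are scored and, for every
            proposal, only the cheaper one contributes to the loss:

            >>> import torch
            >>> loss_perm = torch.tensor([[0.6, 0.4], [0.2, 0.9]])
            >>> loss_perm.min(dim=1)[0]  # kept per-proposal losses
            tensor([0.4000, 0.2000])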
- """
        losses = dict()
        if bbox_pred.numel():
            loss_0 = self.emd_loss(bbox_pred[:, 0:4], cls_score[:, 0:2],
                                   bbox_pred[:, 4:8], cls_score[:, 2:4],
                                   bbox_targets, labels)
            loss_1 = self.emd_loss(bbox_pred[:, 4:8], cls_score[:, 2:4],
                                   bbox_pred[:, 0:4], cls_score[:, 0:2],
                                   bbox_targets, labels)
            loss = torch.cat([loss_0, loss_1], dim=1)
            _, min_indices = loss.min(dim=1)
            loss_emd = loss[torch.arange(loss.shape[0]), min_indices]
            loss_emd = loss_emd.mean()
        else:
            loss_emd = bbox_pred.sum()
        losses['loss_rcnn_emd'] = loss_emd
        return losses

    def emd_loss(self, bbox_pred_0: Tensor, cls_score_0: Tensor,
                 bbox_pred_1: Tensor, cls_score_1: Tensor, targets: Tensor,
                 labels: Tensor) -> Tensor:
- """Calculate the emd loss.
- Note:
- This implementation is modified from https://github.com/Purkialo/
- CrowdDet/blob/master/lib/det_oprs/loss_opr.py
- Args:
- bbox_pred_0 (Tensor): Part of regression prediction results, has
- shape (batch_size * num_proposals_single_image, 4), the last
- dimension 4 represents [tl_x, tl_y, br_x, br_y].
- cls_score_0 (Tensor): Part of classification prediction results,
- has shape (batch_size * num_proposals_single_image,
- (num_classes + 1)), where 1 represents the background.
- bbox_pred_1 (Tensor): The other part of regression prediction
- results, has shape (batch_size*num_proposals_single_image, 4).
- cls_score_1 (Tensor):The other part of classification prediction
- results, has shape (batch_size * num_proposals_single_image,
- (num_classes + 1)).
- targets (Tensor):Regression target for all proposals in a
- batch, has shape (batch_size * num_proposals_single_image,
- 4 * k), the last dimension 4 represents [tl_x, tl_y, br_x,
- br_y], k represents the number of prediction boxes generated
- by each proposal box.
- labels (Tensor): Gt_labels for all proposals in a batch, has
- shape (batch_size * num_proposals_single_image, k).
- Returns:
- torch.Tensor: The calculated loss.
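
        Example:
            A hedged sketch of the pair flattening: the two branches'
            predictions are interleaved per proposal before the per-box
            losses are computed and summed pair-wise:

            >>> import torch
            >>> p0 = torch.tensor([[1., 1., 1., 1.], [2., 2., 2., 2.]])
            >>> p1 = torch.tensor([[3., 3., 3., 3.], [4., 4., 4., 4.]])
            >>> torch.cat([p0, p1], dim=1).reshape(-1, 4)
            tensor([[1., 1., 1., 1.],
                    [3., 3., 3., 3.],
                    [2., 2., 2., 2.],
                    [4., 4., 4., 4.]])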
- """
        bbox_pred = torch.cat([bbox_pred_0, bbox_pred_1],
                              dim=1).reshape(-1, bbox_pred_0.shape[-1])
        cls_score = torch.cat([cls_score_0, cls_score_1],
                              dim=1).reshape(-1, cls_score_0.shape[-1])
        targets = targets.reshape(-1, 4)
        labels = labels.long().flatten()

        # masks
        valid_masks = labels >= 0
        fg_masks = labels > 0

        # multiple class
        bbox_pred = bbox_pred.reshape(-1, self.num_classes, 4)
        fg_gt_classes = labels[fg_masks]
        bbox_pred = bbox_pred[fg_masks, fg_gt_classes - 1, :]

        # loss for regression
        loss_bbox = self.loss_bbox(bbox_pred, targets[fg_masks])
        loss_bbox = loss_bbox.sum(dim=1)

        # loss for classification
        labels = labels * valid_masks
        loss_cls = self.loss_cls(cls_score, labels)

        loss_cls[fg_masks] = loss_cls[fg_masks] + loss_bbox
        loss = loss_cls.reshape(-1, 2).sum(dim=1)
        return loss.reshape(-1, 1)

    def _predict_by_feat_single(
            self,
            roi: Tensor,
            cls_score: Tensor,
            bbox_pred: Tensor,
            img_meta: dict,
            rescale: bool = False,
            rcnn_test_cfg: Optional[ConfigDict] = None) -> InstanceData:
- """Transform a single image's features extracted from the head into
- bbox results.
- Args:
- roi (Tensor): Boxes to be transformed. Has shape (num_boxes, 5).
- last dimension 5 arrange as (batch_index, x1, y1, x2, y2).
- cls_score (Tensor): Box scores, has shape
- (num_boxes, num_classes + 1).
- bbox_pred (Tensor): Box energies / deltas. has shape
- (num_boxes, num_classes * 4).
- img_meta (dict): image information.
- rescale (bool): If True, return boxes in original image space.
- Defaults to False.
- rcnn_test_cfg (obj:`ConfigDict`): `test_cfg` of Bbox Head.
- Defaults to None
- Returns:
- :obj:`InstanceData`: Detection results of each image.
- Each item usually contains following keys.
- - scores (Tensor): Classification scores, has a shape
- (num_instance, )
- - labels (Tensor): Labels of bboxes, has a shape
- (num_instances, ).
- - bboxes (Tensor): Has a shape (num_instances, 4),
- the last dimension 4 arrange as (x1, y1, x2, y2).
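
        Example:
            A hedged sketch of the ``roi_idx`` bookkeeping used before set
            NMS: each RoI yields ``num_instance`` boxes, and the index
            records which RoI every box came from:

            >>> import numpy as np
            >>> np.tile(np.arange(3)[:, None], (1, 2)).reshape(-1, 1)[:, 0]
            array([0, 0, 1, 1, 2, 2])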
- """
        cls_score = cls_score.reshape(-1, self.num_classes + 1)
        bbox_pred = bbox_pred.reshape(-1, 4)
        roi = roi.repeat_interleave(self.num_instance, dim=0)

        results = InstanceData()
        if roi.shape[0] == 0:
            return empty_instances([img_meta],
                                   roi.device,
                                   task_type='bbox',
                                   instance_results=[results])[0]

        scores = cls_score.softmax(dim=-1) if cls_score is not None else None
        img_shape = img_meta['img_shape']
        bboxes = self.bbox_coder.decode(
            roi[..., 1:], bbox_pred, max_shape=img_shape)

        if rescale and bboxes.size(0) > 0:
            assert img_meta.get('scale_factor') is not None
            scale_factor = bboxes.new_tensor(img_meta['scale_factor']).repeat(
                (1, 2))
            bboxes = (bboxes.view(bboxes.size(0), -1, 4) / scale_factor).view(
                bboxes.size()[0], -1)

        if rcnn_test_cfg is None:
            # This means that it is aug test.
            # It needs to return the raw results without nms.
            results.bboxes = bboxes
            results.scores = scores
        else:
            # integer division keeps roi_idx integral, so concatenating it
            # with the float bboxes does not silently upcast them
            roi_idx = np.tile(
                np.arange(bboxes.shape[0] // self.num_instance)[:, None],
                (1, self.num_instance)).reshape(-1, 1)[:, 0]
            roi_idx = torch.from_numpy(roi_idx).to(bboxes.device).reshape(
                -1, 1)
            bboxes = torch.cat([bboxes, roi_idx], dim=1)
            det_bboxes, det_scores = self.set_nms(
                bboxes, scores[:, 1], rcnn_test_cfg.score_thr,
                rcnn_test_cfg.nms['iou_threshold'], rcnn_test_cfg.max_per_img)

            results.bboxes = det_bboxes[:, :-1]
            results.scores = det_scores
            results.labels = torch.zeros_like(det_scores)

        return results

    @staticmethod
    def set_nms(bboxes: Tensor,
                scores: Tensor,
                score_thr: float,
                iou_threshold: float,
                max_num: int = -1) -> Tuple[Tensor, Tensor]:
- """NMS for multi-instance prediction. Please refer to
- https://github.com/Purkialo/CrowdDet for more details.
- Args:
- bboxes (Tensor): predict bboxes.
- scores (Tensor): The score of each predict bbox.
- score_thr (float): bbox threshold, bboxes with scores lower than it
- will not be considered.
- iou_threshold (float): IoU threshold to be considered as
- conflicted.
- max_num (int, optional): if there are more than max_num bboxes
- after NMS, only top max_num will be kept. Default to -1.
- Returns:
- Tuple[Tensor, Tensor]: (bboxes, scores).
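
        Example:
            A hedged toy run: two fully-overlapping boxes from the same RoI
            both survive set NMS, while the duplicate coming from another
            RoI is suppressed:

            >>> import torch
            >>> bboxes = torch.tensor([[0., 0., 10., 10., 0.],
            ...                        [0., 0., 10., 10., 0.],
            ...                        [0., 0., 10., 10., 1.]])
            >>> scores = torch.tensor([0.9, 0.8, 0.7])
            >>> _, kept_scores = MultiInstanceBBoxHead.set_nms(
            ...     bboxes, scores, 0.1, 0.5)
            >>> kept_scores
            tensor([0.9000, 0.8000])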
- """
        bboxes = bboxes[scores > score_thr]
        scores = scores[scores > score_thr]

        ordered_scores, order = scores.sort(descending=True)
        ordered_bboxes = bboxes[order]
        roi_idx = ordered_bboxes[:, -1]

        keep = torch.ones(len(ordered_bboxes)) == 1
        ruler = torch.arange(len(ordered_bboxes))
        while ruler.shape[0] > 0:
            basement = ruler[0]
            ruler = ruler[1:]
            idx = roi_idx[basement]
            # calculate the body overlap
            basement_bbox = ordered_bboxes[:, :4][basement].reshape(-1, 4)
            ruler_bbox = ordered_bboxes[:, :4][ruler].reshape(-1, 4)
            overlap = bbox_overlaps(basement_bbox, ruler_bbox)
            indices = torch.where(overlap > iou_threshold)[1]
            loc = torch.where(roi_idx[ruler][indices] == idx)
            # the mask won't change in the step
            mask = keep[ruler[indices][loc]]

            keep[ruler[indices]] = False
            keep[ruler[indices][loc][mask]] = True
            ruler[~keep[ruler]] = -1
            ruler = ruler[ruler > 0]

        keep = keep[order.sort()[1]]
        bboxes = bboxes[keep]
        scores = scores[keep]
        # only truncate when max_num is positive; -1 keeps all boxes
        if max_num > 0:
            bboxes = bboxes[:max_num]
            scores = scores[:max_num]
        return bboxes, scores