multi_instance_bbox_head.py

# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmengine.config import ConfigDict
from mmengine.structures import InstanceData
from torch import Tensor, nn

from mmdet.models.roi_heads.bbox_heads.bbox_head import BBoxHead
from mmdet.models.task_modules.samplers import SamplingResult
from mmdet.models.utils import empty_instances
from mmdet.registry import MODELS
from mmdet.structures.bbox import bbox_overlaps

@MODELS.register_module()
class MultiInstanceBBoxHead(BBoxHead):
    r"""Bbox head used in CrowdDet.

    .. code-block:: none

                                          /-> cls convs_1 -> cls fcs_1 -> cls_1
                                       |--
                                       |  \-> reg convs_1 -> reg fcs_1 -> reg_1
                                       |
                                       |  /-> cls convs_2 -> cls fcs_2 -> cls_2
            shared convs -> shared fcs |--
                                       |  \-> reg convs_2 -> reg fcs_2 -> reg_2
                                       |
                                       |                     ...
                                       |
                                       |  /-> cls convs_k -> cls fcs_k -> cls_k
                                       |--
                                          \-> reg convs_k -> reg fcs_k -> reg_k

    Args:
        num_instance (int): The number of branches after shared fcs.
            Defaults to 2.
        with_refine (bool): Whether to use refine module. Defaults to False.
        num_shared_convs (int): The number of shared convs. Defaults to 0.
        num_shared_fcs (int): The number of shared fcs. Defaults to 2.
        num_cls_convs (int): The number of cls convs. Defaults to 0.
        num_cls_fcs (int): The number of cls fcs. Defaults to 0.
        num_reg_convs (int): The number of reg convs. Defaults to 0.
        num_reg_fcs (int): The number of reg fcs. Defaults to 0.
        conv_out_channels (int): The number of conv out channels.
            Defaults to 256.
        fc_out_channels (int): The number of fc out channels. Defaults to 1024.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Defaults to None.
    """  # noqa: W605
    def __init__(self,
                 num_instance: int = 2,
                 with_refine: bool = False,
                 num_shared_convs: int = 0,
                 num_shared_fcs: int = 2,
                 num_cls_convs: int = 0,
                 num_cls_fcs: int = 0,
                 num_reg_convs: int = 0,
                 num_reg_fcs: int = 0,
                 conv_out_channels: int = 256,
                 fc_out_channels: int = 1024,
                 init_cfg: Optional[Union[dict, ConfigDict]] = None,
                 *args,
                 **kwargs) -> None:
        super().__init__(*args, init_cfg=init_cfg, **kwargs)
        assert (num_shared_convs + num_shared_fcs + num_cls_convs +
                num_cls_fcs + num_reg_convs + num_reg_fcs > 0)
        assert num_instance == 2, 'Currently only 2 instances are supported'
        if num_cls_convs > 0 or num_reg_convs > 0:
            assert num_shared_fcs == 0
        if not self.with_cls:
            assert num_cls_convs == 0 and num_cls_fcs == 0
        if not self.with_reg:
            assert num_reg_convs == 0 and num_reg_fcs == 0
        self.num_instance = num_instance
        self.num_shared_convs = num_shared_convs
        self.num_shared_fcs = num_shared_fcs
        self.num_cls_convs = num_cls_convs
        self.num_cls_fcs = num_cls_fcs
        self.num_reg_convs = num_reg_convs
        self.num_reg_fcs = num_reg_fcs
        self.conv_out_channels = conv_out_channels
        self.fc_out_channels = fc_out_channels
        self.with_refine = with_refine

        # add shared convs and fcs
        self.shared_convs, self.shared_fcs, last_layer_dim = \
            self._add_conv_fc_branch(
                self.num_shared_convs, self.num_shared_fcs, self.in_channels,
                True)
        self.shared_out_channels = last_layer_dim
        self.relu = nn.ReLU(inplace=True)
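
        # The refine branch consumes the shared RoI feature concatenated
        # with a 20-dim prediction embedding: per instance, `forward` tiles
        # (4 bbox deltas + 1 foreground score) four times, hence the `+ 20`.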
        if self.with_refine:
            refine_model_cfg = {
                'type': 'Linear',
                'in_features': self.shared_out_channels + 20,
                'out_features': self.shared_out_channels
            }
            self.shared_fcs_ref = MODELS.build(refine_model_cfg)
            self.fc_cls_ref = nn.ModuleList()
            self.fc_reg_ref = nn.ModuleList()

        self.cls_convs = nn.ModuleList()
        self.cls_fcs = nn.ModuleList()
        self.reg_convs = nn.ModuleList()
        self.reg_fcs = nn.ModuleList()
        self.cls_last_dim = list()
        self.reg_last_dim = list()
        self.fc_cls = nn.ModuleList()
        self.fc_reg = nn.ModuleList()
        for k in range(self.num_instance):
            # add cls specific branch
            cls_convs, cls_fcs, cls_last_dim = self._add_conv_fc_branch(
                self.num_cls_convs, self.num_cls_fcs, self.shared_out_channels)
            self.cls_convs.append(cls_convs)
            self.cls_fcs.append(cls_fcs)
            self.cls_last_dim.append(cls_last_dim)

            # add reg specific branch
            reg_convs, reg_fcs, reg_last_dim = self._add_conv_fc_branch(
                self.num_reg_convs, self.num_reg_fcs, self.shared_out_channels)
            self.reg_convs.append(reg_convs)
            self.reg_fcs.append(reg_fcs)
            self.reg_last_dim.append(reg_last_dim)

            if self.num_shared_fcs == 0 and not self.with_avg_pool:
                if self.num_cls_fcs == 0:
                    # scale the k-th entry only; `list *= int` would
                    # replicate the list instead of scaling the dim
                    self.cls_last_dim[k] *= self.roi_feat_area
                if self.num_reg_fcs == 0:
                    self.reg_last_dim[k] *= self.roi_feat_area

            if self.with_cls:
                if self.custom_cls_channels:
                    cls_channels = self.loss_cls.get_cls_channels(
                        self.num_classes)
                else:
                    cls_channels = self.num_classes + 1
                cls_predictor_cfg_ = self.cls_predictor_cfg.copy()  # deepcopy
                cls_predictor_cfg_.update(
                    in_features=self.cls_last_dim[k],
                    out_features=cls_channels)
                self.fc_cls.append(MODELS.build(cls_predictor_cfg_))
                if self.with_refine:
                    self.fc_cls_ref.append(MODELS.build(cls_predictor_cfg_))

            if self.with_reg:
                out_dim_reg = (4 if self.reg_class_agnostic else 4 *
                               self.num_classes)
                reg_predictor_cfg_ = self.reg_predictor_cfg.copy()
                reg_predictor_cfg_.update(
                    in_features=self.reg_last_dim[k], out_features=out_dim_reg)
                self.fc_reg.append(MODELS.build(reg_predictor_cfg_))
                if self.with_refine:
                    self.fc_reg_ref.append(MODELS.build(reg_predictor_cfg_))

        if init_cfg is None:
            # when init_cfg is None,
            # it has been set to
            # [[dict(type='Normal', std=0.01, override=dict(name='fc_cls'))],
            #  [dict(type='Normal', std=0.001, override=dict(name='fc_reg'))]]
            # after `super(ConvFCBBoxHead, self).__init__()`;
            # we only need to append additional configuration
            # for `shared_fcs`, `cls_fcs` and `reg_fcs`
            self.init_cfg += [
                dict(
                    type='Xavier',
                    distribution='uniform',
                    override=[
                        dict(name='shared_fcs'),
                        dict(name='cls_fcs'),
                        dict(name='reg_fcs')
                    ])
            ]

    def _add_conv_fc_branch(self,
                            num_branch_convs: int,
                            num_branch_fcs: int,
                            in_channels: int,
                            is_shared: bool = False) -> tuple:
        """Add shared or separable branch.

        convs -> avg pool (optional) -> fcs
        """
        last_layer_dim = in_channels
        # add branch specific conv layers
        branch_convs = nn.ModuleList()
        if num_branch_convs > 0:
            for i in range(num_branch_convs):
                conv_in_channels = (
                    last_layer_dim if i == 0 else self.conv_out_channels)
                branch_convs.append(
                    ConvModule(
                        conv_in_channels, self.conv_out_channels, 3,
                        padding=1))
            last_layer_dim = self.conv_out_channels
        # add branch specific fc layers
        branch_fcs = nn.ModuleList()
        if num_branch_fcs > 0:
            # for shared branch, only consider self.with_avg_pool
            # for separated branches, also consider self.num_shared_fcs
            if (is_shared
                    or self.num_shared_fcs == 0) and not self.with_avg_pool:
                last_layer_dim *= self.roi_feat_area
            for i in range(num_branch_fcs):
                fc_in_channels = (
                    last_layer_dim if i == 0 else self.fc_out_channels)
                branch_fcs.append(
                    nn.Linear(fc_in_channels, self.fc_out_channels))
            last_layer_dim = self.fc_out_channels
        return branch_convs, branch_fcs, last_layer_dim

    def forward(self, x: Tuple[Tensor]) -> tuple:
        """Forward features from the upstream network.

        Args:
            x (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.

        Returns:
            tuple: A tuple of classification scores and bbox predictions.

                - cls_score (Tensor): Classification scores for all
                  proposals, has shape
                  (num_proposals, num_instance * (num_classes + 1)).
                - bbox_pred (Tensor): Box energies / deltas for all
                  proposals, has shape (num_proposals, num_instance * 4).
                - cls_score_ref (Tensor): The cls_score after the refine
                  module, only returned when ``with_refine`` is True.
                - bbox_pred_ref (Tensor): The bbox_pred after the refine
                  module, only returned when ``with_refine`` is True.
        """
        # shared part
        if self.num_shared_convs > 0:
            for conv in self.shared_convs:
                x = conv(x)

        if self.num_shared_fcs > 0:
            if self.with_avg_pool:
                x = self.avg_pool(x)

            x = x.flatten(1)
            for fc in self.shared_fcs:
                x = self.relu(fc(x))

        x_cls = x
        x_reg = x
        # separate branches
        cls_score = list()
        bbox_pred = list()
        for k in range(self.num_instance):
            for conv in self.cls_convs[k]:
                x_cls = conv(x_cls)
            if x_cls.dim() > 2:
                if self.with_avg_pool:
                    x_cls = self.avg_pool(x_cls)
                x_cls = x_cls.flatten(1)
            for fc in self.cls_fcs[k]:
                x_cls = self.relu(fc(x_cls))

            for conv in self.reg_convs[k]:
                x_reg = conv(x_reg)
            if x_reg.dim() > 2:
                if self.with_avg_pool:
                    x_reg = self.avg_pool(x_reg)
                x_reg = x_reg.flatten(1)
            for fc in self.reg_fcs[k]:
                x_reg = self.relu(fc(x_reg))

            cls_score.append(self.fc_cls[k](x_cls) if self.with_cls else None)
            bbox_pred.append(self.fc_reg[k](x_reg) if self.with_reg else None)
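        # Refine module (CrowdDet): embed each branch's raw prediction as
        # [4 deltas, fg score] tiled 4 times (20 dims), concatenate it with
        # the shared feature, then run a second round of cls/reg heads.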
        if self.with_refine:
            x_ref = x
            cls_score_ref = list()
            bbox_pred_ref = list()
            for k in range(self.num_instance):
                feat_ref = cls_score[k].softmax(dim=-1)
                feat_ref = torch.cat((bbox_pred[k], feat_ref[:, 1][:, None]),
                                     dim=1).repeat(1, 4)
                feat_ref = torch.cat((x_ref, feat_ref), dim=1)
                feat_ref = F.relu_(self.shared_fcs_ref(feat_ref))

                cls_score_ref.append(self.fc_cls_ref[k](feat_ref))
                bbox_pred_ref.append(self.fc_reg_ref[k](feat_ref))

            cls_score = torch.cat(cls_score, dim=1)
            bbox_pred = torch.cat(bbox_pred, dim=1)
            cls_score_ref = torch.cat(cls_score_ref, dim=1)
            bbox_pred_ref = torch.cat(bbox_pred_ref, dim=1)
            return cls_score, bbox_pred, cls_score_ref, bbox_pred_ref

        cls_score = torch.cat(cls_score, dim=1)
        bbox_pred = torch.cat(bbox_pred, dim=1)
        return cls_score, bbox_pred

    def get_targets(self,
                    sampling_results: List[SamplingResult],
                    rcnn_train_cfg: ConfigDict,
                    concat: bool = True) -> tuple:
        """Calculate the ground truth for all samples in a batch according to
        the sampling_results.

        Almost the same as the implementation in bbox_head: here each sampled
        proposal carries ``num_instance`` matched gt boxes, so the targets are
        built directly from the sampling results instead of going through
        ``_get_targets_single``.

        Args:
            sampling_results (List[obj:SamplingResult]): Assign results of
                all images in a batch after sampling.
            rcnn_train_cfg (obj:ConfigDict): `train_cfg` of RCNN.
            concat (bool): Whether to concatenate the results of all
                the images in a single batch.

        Returns:
            Tuple[Tensor]: Ground truth for proposals in a batch.
            Containing the following list of Tensors:

            - labels (list[Tensor],Tensor): Gt_labels for all proposals in a
              batch, each tensor in list has shape (num_proposals,) when
              `concat=False`, otherwise just a single tensor has shape
              (num_all_proposals,).
            - label_weights (list[Tensor]): Labels_weights for
              all proposals in a batch, each tensor in list has shape
              (num_proposals,) when `concat=False`, otherwise just a single
              tensor has shape (num_all_proposals,).
            - bbox_targets (list[Tensor],Tensor): Regression target for all
              proposals in a batch, each tensor in list has shape
              (num_proposals, 4) when `concat=False`, otherwise just a single
              tensor has shape (num_all_proposals, 4), the last dimension 4
              represents [tl_x, tl_y, br_x, br_y].
            - bbox_weights (list[Tensor],Tensor): Regression weights for
              all proposals in a batch, each tensor in list has shape
              (num_proposals, 4) when `concat=False`, otherwise just a
              single tensor has shape (num_all_proposals, 4).
        """
        labels = []
        bbox_targets = []
        bbox_weights = []
        label_weights = []
        for i in range(len(sampling_results)):
            sample_bboxes = torch.cat([
                sampling_results[i].pos_gt_bboxes,
                sampling_results[i].neg_gt_bboxes
            ])
            sample_priors = sampling_results[i].priors
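            # Each proposal is matched to `num_instance` gt boxes, so tile
            # the prior `num_instance` times to pair one copy of the prior
            # with each gt box before encoding.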
            sample_priors = sample_priors.repeat(1,
                                                 self.num_instance).reshape(
                                                     -1, 4)
            sample_bboxes = sample_bboxes.reshape(-1, 4)

            if not self.reg_decoded_bbox:
                _bbox_targets = self.bbox_coder.encode(sample_priors,
                                                       sample_bboxes)
            else:
                _bbox_targets = sample_priors
            _bbox_targets = _bbox_targets.reshape(-1, self.num_instance * 4)
            _bbox_weights = torch.ones(_bbox_targets.shape)
            _labels = torch.cat([
                sampling_results[i].pos_gt_labels,
                sampling_results[i].neg_gt_labels
            ])
            _labels_weights = torch.ones(_labels.shape)

            bbox_targets.append(_bbox_targets)
            bbox_weights.append(_bbox_weights)
            labels.append(_labels)
            label_weights.append(_labels_weights)

        if concat:
            labels = torch.cat(labels, 0)
            label_weights = torch.cat(label_weights, 0)
            bbox_targets = torch.cat(bbox_targets, 0)
            bbox_weights = torch.cat(bbox_weights, 0)
        return labels, label_weights, bbox_targets, bbox_weights

    def loss(self, cls_score: Tensor, bbox_pred: Tensor, rois: Tensor,
             labels: Tensor, label_weights: Tensor, bbox_targets: Tensor,
             bbox_weights: Tensor, **kwargs) -> dict:
        """Calculate the loss based on the network predictions and targets.

        Args:
            cls_score (Tensor): Classification prediction results of all
                classes, has shape (batch_size * num_proposals_single_image,
                (num_classes + 1) * k), k represents the number of prediction
                boxes generated by each proposal box.
            bbox_pred (Tensor): Regression prediction results, has shape
                (batch_size * num_proposals_single_image, 4 * k), the last
                dimension 4 represents [tl_x, tl_y, br_x, br_y].
            rois (Tensor): RoIs with the shape
                (batch_size * num_proposals_single_image, 5) where the first
                column indicates batch id of each RoI.
            labels (Tensor): Gt_labels for all proposals in a batch, has
                shape (batch_size * num_proposals_single_image, k).
            label_weights (Tensor): Labels_weights for all proposals in a
                batch, has shape (batch_size * num_proposals_single_image, k).
            bbox_targets (Tensor): Regression target for all proposals in a
                batch, has shape (batch_size * num_proposals_single_image,
                4 * k), the last dimension 4 represents [tl_x, tl_y, br_x,
                br_y].
            bbox_weights (Tensor): Regression weights for all proposals in a
                batch, has shape (batch_size * num_proposals_single_image,
                4 * k).

        Returns:
            dict: A dictionary of loss.
        """
        losses = dict()
        if bbox_pred.numel():
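            # EMD loss (CrowdDet): score both assignments between the two
            # predicted instances and the two targets, then keep the cheaper
            # permutation per proposal. The channel slicing below assumes
            # num_classes == 1, i.e. 2 cls channels and 4 reg channels per
            # instance.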
            loss_0 = self.emd_loss(bbox_pred[:, 0:4], cls_score[:, 0:2],
                                   bbox_pred[:, 4:8], cls_score[:, 2:4],
                                   bbox_targets, labels)
            loss_1 = self.emd_loss(bbox_pred[:, 4:8], cls_score[:, 2:4],
                                   bbox_pred[:, 0:4], cls_score[:, 0:2],
                                   bbox_targets, labels)
            loss = torch.cat([loss_0, loss_1], dim=1)
            _, min_indices = loss.min(dim=1)
            loss_emd = loss[torch.arange(loss.shape[0]), min_indices]
            loss_emd = loss_emd.mean()
        else:
            loss_emd = bbox_pred.sum()
        losses['loss_rcnn_emd'] = loss_emd
        return losses

    def emd_loss(self, bbox_pred_0: Tensor, cls_score_0: Tensor,
                 bbox_pred_1: Tensor, cls_score_1: Tensor, targets: Tensor,
                 labels: Tensor) -> Tensor:
        """Calculate the emd loss.

        Note:
            This implementation is modified from https://github.com/Purkialo/
            CrowdDet/blob/master/lib/det_oprs/loss_opr.py

        Args:
            bbox_pred_0 (Tensor): Part of regression prediction results, has
                shape (batch_size * num_proposals_single_image, 4), the last
                dimension 4 represents [tl_x, tl_y, br_x, br_y].
            cls_score_0 (Tensor): Part of classification prediction results,
                has shape (batch_size * num_proposals_single_image,
                (num_classes + 1)), where 1 represents the background.
            bbox_pred_1 (Tensor): The other part of regression prediction
                results, has shape (batch_size*num_proposals_single_image, 4).
            cls_score_1 (Tensor): The other part of classification prediction
                results, has shape (batch_size * num_proposals_single_image,
                (num_classes + 1)).
            targets (Tensor): Regression target for all proposals in a
                batch, has shape (batch_size * num_proposals_single_image,
                4 * k), the last dimension 4 represents [tl_x, tl_y, br_x,
                br_y], k represents the number of prediction boxes generated
                by each proposal box.
            labels (Tensor): Gt_labels for all proposals in a batch, has
                shape (batch_size * num_proposals_single_image, k).

        Returns:
            torch.Tensor: The calculated loss.
        """
        bbox_pred = torch.cat([bbox_pred_0, bbox_pred_1],
                              dim=1).reshape(-1, bbox_pred_0.shape[-1])
        cls_score = torch.cat([cls_score_0, cls_score_1],
                              dim=1).reshape(-1, cls_score_0.shape[-1])
        targets = targets.reshape(-1, 4)
        labels = labels.long().flatten()

        # masks
        valid_masks = labels >= 0
        fg_masks = labels > 0

        # multiple class
        bbox_pred = bbox_pred.reshape(-1, self.num_classes, 4)
        fg_gt_classes = labels[fg_masks]
        bbox_pred = bbox_pred[fg_masks, fg_gt_classes - 1, :]

        # loss for regression
        loss_bbox = self.loss_bbox(bbox_pred, targets[fg_masks])
        loss_bbox = loss_bbox.sum(dim=1)

        # loss for classification
        labels = labels * valid_masks
        loss_cls = self.loss_cls(cls_score, labels)

        loss_cls[fg_masks] = loss_cls[fg_masks] + loss_bbox
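        # One permutation's cost per proposal: sum the two per-instance
        # (cls + reg) costs so each output row is a single matching cost of
        # shape (num_proposals, 1).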
        loss = loss_cls.reshape(-1, 2).sum(dim=1)
        return loss.reshape(-1, 1)

    def _predict_by_feat_single(
            self,
            roi: Tensor,
            cls_score: Tensor,
            bbox_pred: Tensor,
            img_meta: dict,
            rescale: bool = False,
            rcnn_test_cfg: Optional[ConfigDict] = None) -> InstanceData:
        """Transform a single image's features extracted from the head into
        bbox results.

        Args:
            roi (Tensor): Boxes to be transformed. Has shape (num_boxes, 5).
                The last dimension 5 arranges as
                (batch_index, x1, y1, x2, y2).
            cls_score (Tensor): Box scores, has shape
                (num_boxes, num_classes + 1).
            bbox_pred (Tensor): Box energies / deltas. Has shape
                (num_boxes, num_classes * 4).
            img_meta (dict): Image information.
            rescale (bool): If True, return boxes in original image space.
                Defaults to False.
            rcnn_test_cfg (obj:`ConfigDict`): `test_cfg` of Bbox Head.
                Defaults to None.

        Returns:
            :obj:`InstanceData`: Detection results of each image.
            Each item usually contains following keys.

            - scores (Tensor): Classification scores, has a shape
              (num_instances, ).
            - labels (Tensor): Labels of bboxes, has a shape
              (num_instances, ).
            - bboxes (Tensor): Has a shape (num_instances, 4),
              the last dimension 4 arranges as (x1, y1, x2, y2).
        """
        cls_score = cls_score.reshape(-1, self.num_classes + 1)
        bbox_pred = bbox_pred.reshape(-1, 4)
        roi = roi.repeat_interleave(self.num_instance, dim=0)

        results = InstanceData()
        if roi.shape[0] == 0:
            return empty_instances([img_meta],
                                   roi.device,
                                   task_type='bbox',
                                   instance_results=[results])[0]

        scores = cls_score.softmax(dim=-1) if cls_score is not None else None
        img_shape = img_meta['img_shape']
        bboxes = self.bbox_coder.decode(
            roi[..., 1:], bbox_pred, max_shape=img_shape)

        if rescale and bboxes.size(0) > 0:
            assert img_meta.get('scale_factor') is not None
            scale_factor = bboxes.new_tensor(img_meta['scale_factor']).repeat(
                (1, 2))
            bboxes = (bboxes.view(bboxes.size(0), -1, 4) / scale_factor).view(
                bboxes.size()[0], -1)

        if rcnn_test_cfg is None:
            # This means that it is aug test.
            # It needs to return the raw results without nms.
            results.bboxes = bboxes
            results.scores = scores
        else:
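            # Tag each box with the index of the proposal it came from; set
            # NMS uses this id to avoid suppressing the two instances
            # predicted by the same proposal.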
            roi_idx = np.tile(
                np.arange(bboxes.shape[0] / self.num_instance)[:, None],
                (1, self.num_instance)).reshape(-1, 1)[:, 0]
            roi_idx = torch.from_numpy(roi_idx).to(bboxes.device).reshape(
                -1, 1)
            bboxes = torch.cat([bboxes, roi_idx], dim=1)
            det_bboxes, det_scores = self.set_nms(
                bboxes, scores[:, 1], rcnn_test_cfg.score_thr,
                rcnn_test_cfg.nms['iou_threshold'], rcnn_test_cfg.max_per_img)

            results.bboxes = det_bboxes[:, :-1]
            results.scores = det_scores
            results.labels = torch.zeros_like(det_scores)

        return results

    @staticmethod
    def set_nms(bboxes: Tensor,
                scores: Tensor,
                score_thr: float,
                iou_threshold: float,
                max_num: int = -1) -> Tuple[Tensor, Tensor]:
        """NMS for multi-instance prediction. Please refer to
        https://github.com/Purkialo/CrowdDet for more details.

        Args:
            bboxes (Tensor): Predicted bboxes; the last column holds the
                index of the proposal each box was generated from.
            scores (Tensor): The score of each predicted bbox.
            score_thr (float): Bbox threshold, bboxes with scores lower than
                it will not be considered.
            iou_threshold (float): IoU threshold to be considered as
                conflicted.
            max_num (int, optional): If there are more than max_num bboxes
                after NMS, only the top max_num will be kept. Defaults to -1,
                meaning all surviving bboxes are kept.

        Returns:
            Tuple[Tensor, Tensor]: (bboxes, scores).
        """
        bboxes = bboxes[scores > score_thr]
        scores = scores[scores > score_thr]

        ordered_scores, order = scores.sort(descending=True)
        ordered_bboxes = bboxes[order]
        roi_idx = ordered_bboxes[:, -1]

        keep = torch.ones(len(ordered_bboxes)) == 1
        ruler = torch.arange(len(ordered_bboxes))
        while ruler.shape[0] > 0:
            basement = ruler[0]
            ruler = ruler[1:]
            idx = roi_idx[basement]
            # calculate the body overlap
            basement_bbox = ordered_bboxes[:, :4][basement].reshape(-1, 4)
            ruler_bbox = ordered_bboxes[:, :4][ruler].reshape(-1, 4)
            overlap = bbox_overlaps(basement_bbox, ruler_bbox)
            indices = torch.where(overlap > iou_threshold)[1]
            loc = torch.where(roi_idx[ruler][indices] == idx)
            # re-mark still-alive boxes that share the basement's proposal
            # id; the mask snapshot keeps this step order-independent
            mask = keep[ruler[indices][loc]]
            keep[ruler[indices]] = False
            keep[ruler[indices][loc][mask]] = True
            ruler[~keep[ruler]] = -1
            ruler = ruler[ruler > 0]

        keep = keep[order.sort()[1]]
        # max_num == -1 means keep everything; slicing with [:max_num] would
        # wrongly drop the last surviving box in that case
        if max_num > 0:
            return bboxes[keep][:max_num, :], scores[keep][:max_num]
        return bboxes[keep], scores[keep]
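

if __name__ == '__main__':
    # Minimal smoke test for `set_nms` (an illustrative addition, not part
    # of the upstream file). Two instances predicted by the same proposal
    # overlap heavily and must both survive, while a box from a different
    # proposal is suppressed by its overlap with the top-scoring box.
    boxes = torch.tensor([
        [0.0, 0.0, 10.0, 10.0, 0.0],  # proposal 0, instance A
        [1.0, 1.0, 11.0, 11.0, 0.0],  # proposal 0, instance B (kept)
        [0.5, 0.5, 10.5, 10.5, 1.0],  # proposal 1, overlaps A (dropped)
    ])
    scores = torch.tensor([0.9, 0.8, 0.7])
    kept_boxes, kept_scores = MultiInstanceBBoxHead.set_nms(
        boxes, scores, score_thr=0.1, iou_threshold=0.5)
    # Expect both proposal-0 boxes kept and the proposal-1 box suppressed.
    print(kept_boxes, kept_scores)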