atss_vlfusion_head.py

# Copyright (c) OpenMMLab. All rights reserved.
import copy
import math
from typing import Callable, List, Optional, Sequence, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import Scale
from mmcv.ops.modulated_deform_conv import ModulatedDeformConv2d
from mmengine.config import ConfigDict
from mmengine.model import BaseModel
from mmengine.structures import InstanceData
from torch import Tensor

try:
    from transformers import BertConfig
except ImportError:
    BertConfig = None

from mmdet.registry import MODELS
from mmdet.structures.bbox import cat_boxes
from mmdet.utils import InstanceList
from ..utils import (BertEncoderLayer, VLFuse, filter_scores_and_topk,
                     permute_and_flatten, select_single_mlvl)
from ..utils.vlfuse_helper import MAX_CLAMP_VALUE
from .atss_head import ATSSHead
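

# ``convert_grounding_to_cls_scores`` maps token-level grounding logits of
# shape (batch, num_queries, num_tokens) to class-level scores of shape
# (batch, num_queries, num_classes) by averaging the logits over the token
# positions assigned to each label. For example, a positive map such as
# {1: [2, 3], 2: [5]} (illustrative values only) scores class 1 by averaging
# token columns 2 and 3 and class 2 by token column 5; label ids are 1-based
# while the output columns are 0-based.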
def convert_grounding_to_cls_scores(logits: Tensor,
                                    positive_maps: List[dict]) -> Tensor:
    """Convert logits to class scores."""
    assert len(positive_maps) == logits.shape[0]  # batch size
    scores = torch.zeros(logits.shape[0], logits.shape[1],
                         len(positive_maps[0])).to(logits.device)
    if positive_maps is not None:
        if all(x == positive_maps[0] for x in positive_maps):
            # only need to compute once
            positive_map = positive_maps[0]
            for label_j in positive_map:
                scores[:, :, label_j - 1] = logits[
                    :, :, torch.LongTensor(positive_map[label_j])].mean(-1)
        else:
            for i, positive_map in enumerate(positive_maps):
                for label_j in positive_map:
                    scores[i, :, label_j - 1] = logits[
                        i, :, torch.LongTensor(positive_map[label_j])].mean(-1)
    return scores


class Conv3x3Norm(nn.Module):
    """Conv3x3 and norm."""

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 stride: int,
                 groups: int = 1,
                 use_dcn: bool = False,
                 norm_type: Optional[Union[Sequence, str]] = None):
        super().__init__()

        if use_dcn:
            self.conv = ModulatedDeformConv2d(
                in_channels,
                out_channels,
                kernel_size=3,
                stride=stride,
                padding=1,
                groups=groups)
        else:
            self.conv = nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=3,
                stride=stride,
                padding=1,
                groups=groups)

        if isinstance(norm_type, Sequence):
            assert len(norm_type) == 2
            assert norm_type[0] == 'gn'
            gn_group = norm_type[1]
            norm_type = norm_type[0]

        if norm_type == 'bn':
            bn_op = nn.BatchNorm2d(out_channels)
        elif norm_type == 'gn':
            bn_op = nn.GroupNorm(
                num_groups=gn_group, num_channels=out_channels)
        if norm_type is not None:
            self.bn = bn_op
        else:
            self.bn = None

    def forward(self, x, **kwargs):
        x = self.conv(x, **kwargs)
        if self.bn:
            x = self.bn(x)
        return x
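

# DyReLU replaces a static activation with an input-conditioned piecewise
# linear function max(a1 * x + b1, a2 * x + b2). The four coefficients are
# predicted per channel from globally average-pooled features; with the
# shifts applied in ``forward`` (a1 near 1, a2, b1, b2 near 0 at
# initialization), the module starts out close to a standard ReLU and learns
# per-channel slopes and offsets during training.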
class DyReLU(nn.Module):
    """Dynamic ReLU."""

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 expand_ratio: int = 4):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.expand_ratio = expand_ratio
        self.out_channels = out_channels

        self.fc = nn.Sequential(
            nn.Linear(in_channels, in_channels // expand_ratio),
            nn.ReLU(inplace=True),
            nn.Linear(in_channels // expand_ratio,
                      out_channels * self.expand_ratio),
            nn.Hardsigmoid(inplace=True))

    def forward(self, x) -> Tensor:
        x_out = x
        b, c, h, w = x.size()
        x = self.avg_pool(x).view(b, c)
        x = self.fc(x).view(b, -1, 1, 1)

        a1, b1, a2, b2 = torch.split(x, self.out_channels, dim=1)
        a1 = (a1 - 0.5) * 2 + 1.0
        a2 = (a2 - 0.5) * 2
        b1 = b1 - 0.5
        b2 = b2 - 0.5
        out = torch.max(x_out * a1 + b1, x_out * a2 + b2)
        return out
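

# DyConv fuses each pyramid level with its neighbours: dyconvs[1] processes
# the current level, dyconvs[2] (stride 2) downsamples the higher-resolution
# level below it, and dyconvs[0] processes the lower-resolution level above
# it, which is then upsampled back to the current size. The per-level results
# are averaged, optionally re-weighted by the scale attention branch when
# ``use_dyfuse`` is enabled, and passed through DyReLU or a plain ReLU. When
# ``use_dcn`` is enabled, a shared conv predicts the deformable offsets and
# masks fed to every branch.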
class DyConv(nn.Module):
    """Dynamic Convolution."""

    def __init__(self,
                 conv_func: Callable,
                 in_channels: int,
                 out_channels: int,
                 use_dyfuse: bool = True,
                 use_dyrelu: bool = False,
                 use_dcn: bool = False):
        super().__init__()

        self.dyconvs = nn.ModuleList()
        self.dyconvs.append(conv_func(in_channels, out_channels, 1))
        self.dyconvs.append(conv_func(in_channels, out_channels, 1))
        self.dyconvs.append(conv_func(in_channels, out_channels, 2))

        if use_dyfuse:
            self.attnconv = nn.Sequential(
                nn.AdaptiveAvgPool2d(1),
                nn.Conv2d(in_channels, 1, kernel_size=1),
                nn.ReLU(inplace=True))
            self.h_sigmoid = nn.Hardsigmoid(inplace=True)
        else:
            self.attnconv = None

        if use_dyrelu:
            self.relu = DyReLU(in_channels, out_channels)
        else:
            self.relu = nn.ReLU()

        if use_dcn:
            self.offset = nn.Conv2d(
                in_channels, 27, kernel_size=3, stride=1, padding=1)
        else:
            self.offset = None

        self.init_weights()

    def init_weights(self):
        for m in self.dyconvs.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight.data, 0, 0.01)
                if m.bias is not None:
                    m.bias.data.zero_()
        if self.attnconv is not None:
            for m in self.attnconv.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.normal_(m.weight.data, 0, 0.01)
                    if m.bias is not None:
                        m.bias.data.zero_()

    def forward(self, inputs: dict) -> dict:
        visual_feats = inputs['visual']

        out_vis_feats = []
        for level, feature in enumerate(visual_feats):

            offset_conv_args = {}
            if self.offset is not None:
                offset_mask = self.offset(feature)
                offset = offset_mask[:, :18, :, :]
                mask = offset_mask[:, 18:, :, :].sigmoid()
                offset_conv_args = dict(offset=offset, mask=mask)

            temp_feats = [self.dyconvs[1](feature, **offset_conv_args)]

            if level > 0:
                temp_feats.append(self.dyconvs[2](visual_feats[level - 1],
                                                  **offset_conv_args))
            if level < len(visual_feats) - 1:
                temp_feats.append(
                    F.upsample_bilinear(
                        self.dyconvs[0](visual_feats[level + 1],
                                        **offset_conv_args),
                        size=[feature.size(2),
                              feature.size(3)]))

            mean_feats = torch.mean(
                torch.stack(temp_feats), dim=0, keepdim=False)

            if self.attnconv is not None:
                attn_feat = []
                res_feat = []
                for feat in temp_feats:
                    res_feat.append(feat)
                    attn_feat.append(self.attnconv(feat))

                res_feat = torch.stack(res_feat)
                spa_pyr_attn = self.h_sigmoid(torch.stack(attn_feat))

                mean_feats = torch.mean(
                    res_feat * spa_pyr_attn, dim=0, keepdim=False)

            out_vis_feats.append(mean_feats)

        out_vis_feats = [self.relu(item) for item in out_vis_feats]

        features_dict = {'visual': out_vis_feats, 'lang': inputs['lang']}

        return features_dict
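

# VLFusionModule stacks ``num_dyhead_blocks`` dynamic-head blocks. In the
# early-fusion setting each block interleaves a VLFuse cross-modality layer
# and a BertEncoderLayer (language branch) with the DyConv vision branch;
# otherwise only the DyConv branch is used and the language embedding comes
# straight from the language model. Its forward pass returns per-level bbox
# deltas, centerness, and word-region alignment logits.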
class VLFusionModule(BaseModel):
    """Visual-lang Fusion Module."""

    def __init__(self,
                 in_channels: int,
                 feat_channels: int,
                 num_base_priors: int,
                 early_fuse: bool = False,
                 num_dyhead_blocks: int = 6,
                 lang_model_name: str = 'bert-base-uncased',
                 use_dyrelu: bool = True,
                 use_dyfuse: bool = True,
                 use_dcn: bool = True,
                 use_checkpoint: bool = False,
                 **kwargs) -> None:
        super().__init__(**kwargs)
        if BertConfig is None:
            raise RuntimeError(
                'transformers is not installed, please install it by: '
                'pip install transformers.')
        self.in_channels = in_channels
        self.feat_channels = feat_channels
        self.num_base_priors = num_base_priors
        self.early_fuse = early_fuse
        self.num_dyhead_blocks = num_dyhead_blocks
        self.use_dyrelu = use_dyrelu
        self.use_dyfuse = use_dyfuse
        self.use_dcn = use_dcn
        self.use_checkpoint = use_checkpoint

        self.lang_cfg = BertConfig.from_pretrained(lang_model_name)
        self.lang_dim = self.lang_cfg.hidden_size
        self._init_layers()

    def _init_layers(self) -> None:
        """Initialize layers of the model."""
        bias_value = -math.log((1 - 0.01) / 0.01)

        dyhead_tower = []
        for i in range(self.num_dyhead_blocks):
            if self.early_fuse:
                # cross-modality fusion
                dyhead_tower.append(VLFuse(use_checkpoint=self.use_checkpoint))
                # lang branch
                dyhead_tower.append(
                    BertEncoderLayer(
                        self.lang_cfg,
                        clamp_min_for_underflow=True,
                        clamp_max_for_overflow=True))

            # vision branch
            dyhead_tower.append(
                DyConv(
                    lambda i, o, s: Conv3x3Norm(
                        i, o, s, use_dcn=self.use_dcn, norm_type=['gn', 16]),
                    self.in_channels if i == 0 else self.feat_channels,
                    self.feat_channels,
                    use_dyrelu=(self.use_dyrelu
                                and self.in_channels == self.feat_channels)
                    if i == 0 else self.use_dyrelu,
                    use_dyfuse=(self.use_dyfuse
                                and self.in_channels == self.feat_channels)
                    if i == 0 else self.use_dyfuse,
                    use_dcn=(self.use_dcn
                             and self.in_channels == self.feat_channels)
                    if i == 0 else self.use_dcn,
                ))

        self.add_module('dyhead_tower', nn.Sequential(*dyhead_tower))
        self.bbox_pred = nn.Conv2d(
            self.feat_channels, self.num_base_priors * 4, kernel_size=1)
        self.centerness = nn.Conv2d(
            self.feat_channels, self.num_base_priors * 1, kernel_size=1)
        self.dot_product_projection_text = nn.Linear(
            self.lang_dim,
            self.num_base_priors * self.feat_channels,
            bias=True)
        self.log_scale = nn.Parameter(torch.Tensor([0.0]), requires_grad=True)
        self.bias_lang = nn.Parameter(
            torch.zeros(self.lang_dim), requires_grad=True)
        self.bias0 = nn.Parameter(
            torch.Tensor([bias_value]), requires_grad=True)
        self.scales = nn.ModuleList([Scale(1.0) for _ in range(5)])
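
    # The classification logits produced in ``forward`` below are word-region
    # alignment scores rather than fixed-class scores: each anchor's visual
    # embedding is compared (scaled dot product plus a per-token language
    # bias) against the projected text token embeddings, so the last
    # dimension of ``cls_logits`` indexes text tokens instead of classes.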
    def forward(self, visual_feats: Tuple[Tensor],
                language_feats: dict) -> Tuple:
        feat_inputs = {'visual': visual_feats, 'lang': language_feats}
        dyhead_tower = self.dyhead_tower(feat_inputs)

        if self.early_fuse:
            embedding = dyhead_tower['lang']['hidden']
        else:
            embedding = language_feats['embedded']

        embedding = F.normalize(embedding, p=2, dim=-1)
        dot_product_proj_tokens = self.dot_product_projection_text(
            embedding / 2.0)
        dot_product_proj_tokens_bias = torch.matmul(
            embedding, self.bias_lang) + self.bias0

        bbox_preds = []
        centerness = []
        cls_logits = []

        for i, feature in enumerate(visual_feats):
            visual = dyhead_tower['visual'][i]
            B, C, H, W = visual.shape

            bbox_pred = self.scales[i](self.bbox_pred(visual))
            bbox_preds.append(bbox_pred)
            centerness.append(self.centerness(visual))

            dot_product_proj_queries = permute_and_flatten(
                visual, B, self.num_base_priors, C, H, W)

            bias = dot_product_proj_tokens_bias.unsqueeze(1).repeat(
                1, self.num_base_priors, 1)
            dot_product_logit = (
                torch.matmul(dot_product_proj_queries,
                             dot_product_proj_tokens.transpose(-1, -2)) /
                self.log_scale.exp()) + bias
            dot_product_logit = torch.clamp(
                dot_product_logit, max=MAX_CLAMP_VALUE)
            dot_product_logit = torch.clamp(
                dot_product_logit, min=-MAX_CLAMP_VALUE)
            cls_logits.append(dot_product_logit)

        return bbox_preds, centerness, cls_logits


@MODELS.register_module()
class ATSSVLFusionHead(ATSSHead):
    """ATSS head with visual-language fusion module.

    Args:
        early_fuse (bool): Whether to fuse visual and language features.
            Defaults to False.
        use_checkpoint (bool): Whether to use activation checkpointing in the
            fusion blocks. Defaults to False.
        num_dyhead_blocks (int): Number of dynamic head blocks. Defaults to 6.
        lang_model_name (str): Name of the language model.
            Defaults to 'bert-base-uncased'.
    """

    def __init__(self,
                 *args,
                 early_fuse: bool = False,
                 use_checkpoint: bool = False,
                 num_dyhead_blocks: int = 6,
                 lang_model_name: str = 'bert-base-uncased',
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.head = VLFusionModule(
            in_channels=self.in_channels,
            feat_channels=self.feat_channels,
            num_base_priors=self.num_base_priors,
            early_fuse=early_fuse,
            use_checkpoint=use_checkpoint,
            num_dyhead_blocks=num_dyhead_blocks,
            lang_model_name=lang_model_name)

    def _init_layers(self) -> None:
        """No need to initialize the ATSS head layer."""
        pass

    def forward(self, visual_feats: Tuple[Tensor],
                language_feats: dict) -> Tuple[Tensor]:
        """Forward function."""
        bbox_preds, centerness, cls_logits = self.head(visual_feats,
                                                       language_feats)
        return bbox_preds, centerness, cls_logits

    def predict(self,
                visual_feats: Tuple[Tensor],
                language_feats: dict,
                batch_data_samples,
                rescale: bool = True):
        """Perform forward propagation of the detection head and predict
        detection results on the features of the upstream network.

        Args:
            visual_feats (tuple[Tensor]): Multi-level visual features from the
                upstream network, each is a 4D-tensor.
            language_feats (dict): Language features from the upstream
                network.
            batch_data_samples (List[:obj:`DetDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
            rescale (bool, optional): Whether to rescale the results.
                Defaults to True.

        Returns:
            list[:obj:`InstanceData`]: Detection results of each image
            after the post process.
        """
        batch_img_metas = [
            data_samples.metainfo for data_samples in batch_data_samples
        ]
        batch_token_positive_maps = [
            data_samples.token_positive_map
            for data_samples in batch_data_samples
        ]
        outs = self(visual_feats, language_feats)

        predictions = self.predict_by_feat(
            *outs,
            batch_img_metas=batch_img_metas,
            batch_token_positive_maps=batch_token_positive_maps,
            rescale=rescale)
        return predictions

    def predict_by_feat(self,
                        bbox_preds: List[Tensor],
                        score_factors: List[Tensor],
                        cls_logits: List[Tensor],
                        batch_img_metas: Optional[List[dict]] = None,
                        batch_token_positive_maps: Optional[List[dict]] = None,
                        cfg: Optional[ConfigDict] = None,
                        rescale: bool = False,
                        with_nms: bool = True) -> InstanceList:
        """Transform a batch of output features extracted from the head into
        bbox results.

        Note: When ``score_factors`` is not None, the cls_scores are usually
        multiplied by it to obtain the real scores used in NMS, such as the
        centerness in FCOS or the IoU branch in ATSS.

        Args:
            bbox_preds (list[Tensor]): Box energies / deltas for all
                scale levels, each is a 4D-tensor, has shape
                (batch_size, num_priors * 4, H, W).
            score_factors (list[Tensor], optional): Score factor for
                all scale level, each is a 4D-tensor, has shape
                (batch_size, num_priors * 1, H, W). Defaults to None.
            cls_logits (list[Tensor]): Classification scores for all
                scale levels, each is a 4D-tensor, has shape
                (batch_size, num_priors * num_classes, H, W).
            batch_img_metas (list[dict], Optional): Batch image meta info.
                Defaults to None.
            batch_token_positive_maps (list[dict], Optional): Batch token
                positive map. Defaults to None.
            cfg (ConfigDict, optional): Test / postprocessing
                configuration, if None, test_cfg would be used.
                Defaults to None.
            rescale (bool): If True, return boxes in original image space.
                Defaults to False.
            with_nms (bool): If True, do nms before return boxes.
                Defaults to True.

        Returns:
            list[:obj:`InstanceData`]: Object detection results of each image
            after the post process. Each item usually contains following keys.

            - scores (Tensor): Classification scores, has a shape
              (num_instance, )
            - labels (Tensor): Labels of bboxes, has a shape
              (num_instances, ).
            - bboxes (Tensor): Has a shape (num_instances, 4),
              the last dimension 4 arrange as (x1, y1, x2, y2).
        """
        assert len(bbox_preds) == len(score_factors)
        num_levels = len(bbox_preds)

        featmap_sizes = [bbox_preds[i].shape[-2:] for i in range(num_levels)]
        mlvl_priors = self.prior_generator.grid_priors(
            featmap_sizes,
            dtype=bbox_preds[0].dtype,
            device=bbox_preds[0].device)

        result_list = []

        for img_id in range(len(batch_img_metas)):
            img_meta = batch_img_metas[img_id]
            token_positive_maps = batch_token_positive_maps[img_id]
            bbox_pred_list = select_single_mlvl(
                bbox_preds, img_id, detach=True)
            score_factor_list = select_single_mlvl(
                score_factors, img_id, detach=True)
            cls_logit_list = select_single_mlvl(
                cls_logits, img_id, detach=True)

            results = self._predict_by_feat_single(
                bbox_pred_list=bbox_pred_list,
                score_factor_list=score_factor_list,
                cls_logit_list=cls_logit_list,
                mlvl_priors=mlvl_priors,
                token_positive_maps=token_positive_maps,
                img_meta=img_meta,
                cfg=cfg,
                rescale=rescale,
                with_nms=with_nms)
            result_list.append(results)
        return result_list
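
    # Per level, the grounding logits are first converted to class scores via
    # the token positive map, thresholded and reduced to the top-k candidates,
    # then combined with the centerness branch as sqrt(cls_score * centerness)
    # before the boxes from all levels are decoded and post-processed together.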
    def _predict_by_feat_single(self,
                                bbox_pred_list: List[Tensor],
                                score_factor_list: List[Tensor],
                                cls_logit_list: List[Tensor],
                                mlvl_priors: List[Tensor],
                                token_positive_maps: dict,
                                img_meta: dict,
                                cfg: ConfigDict,
                                rescale: bool = True,
                                with_nms: bool = True) -> InstanceData:
        """Transform a single image's features extracted from the head into
        bbox results.

        Args:
            bbox_pred_list (list[Tensor]): Box energies / deltas from
                all scale levels of a single image, each item has shape
                (num_priors * 4, H, W).
            score_factor_list (list[Tensor]): Score factor from all scale
                levels of a single image, each item has shape
                (num_priors * 1, H, W).
            cls_logit_list (list[Tensor]): Box scores from all scale
                levels of a single image, each item has shape
                (num_priors * num_classes, H, W).
            mlvl_priors (list[Tensor]): Each element in the list is
                the priors of a single level in feature pyramid. In all
                anchor-based methods, it has shape (num_priors, 4). In
                all anchor-free methods, it has shape (num_priors, 2)
                when `with_stride=True`, otherwise it still has shape
                (num_priors, 4).
            token_positive_maps (dict): Token positive map.
            img_meta (dict): Image meta info.
            cfg (mmengine.Config): Test / postprocessing configuration,
                if None, test_cfg would be used.
            rescale (bool): If True, return boxes in original image space.
                Defaults to True.
            with_nms (bool): If True, do nms before return boxes.
                Defaults to True.

        Returns:
            :obj:`InstanceData`: Detection results of each image
            after the post process.
            Each item usually contains following keys.

            - scores (Tensor): Classification scores, has a shape
              (num_instance, )
            - labels (Tensor): Labels of bboxes, has a shape
              (num_instances, ).
            - bboxes (Tensor): Has a shape (num_instances, 4),
              the last dimension 4 arrange as (x1, y1, x2, y2).
        """
        cfg = self.test_cfg if cfg is None else cfg
        cfg = copy.deepcopy(cfg)
        img_shape = img_meta['img_shape']
        nms_pre = cfg.get('nms_pre', -1)
        score_thr = cfg.get('score_thr', 0)

        mlvl_bbox_preds = []
        mlvl_valid_priors = []
        mlvl_scores = []
        mlvl_labels = []

        for level_idx, (bbox_pred, score_factor, cls_logit, priors) in \
                enumerate(zip(bbox_pred_list,
                              score_factor_list, cls_logit_list, mlvl_priors)):
            bbox_pred = bbox_pred.permute(1, 2, 0).reshape(
                -1, self.bbox_coder.encode_size)
            score_factor = score_factor.permute(1, 2, 0).reshape(-1).sigmoid()

            scores = convert_grounding_to_cls_scores(
                logits=cls_logit.sigmoid()[None],
                positive_maps=[token_positive_maps])[0]

            results = filter_scores_and_topk(
                scores, score_thr, nms_pre,
                dict(bbox_pred=bbox_pred, priors=priors))
            scores, labels, keep_idxs, filtered_results = results

            bbox_pred = filtered_results['bbox_pred']
            priors = filtered_results['priors']

            score_factor = score_factor[keep_idxs]
            scores = torch.sqrt(scores * score_factor)

            mlvl_bbox_preds.append(bbox_pred)
            mlvl_valid_priors.append(priors)
            mlvl_scores.append(scores)
            mlvl_labels.append(labels)

        bbox_pred = torch.cat(mlvl_bbox_preds)
        priors = cat_boxes(mlvl_valid_priors)
        bboxes = self.bbox_coder.decode(priors, bbox_pred, max_shape=img_shape)

        results = InstanceData()
        results.bboxes = bboxes
        results.scores = torch.cat(mlvl_scores)
        results.labels = torch.cat(mlvl_labels)

        predictions = self._bbox_post_process(
            results=results,
            cfg=cfg,
            rescale=rescale,
            with_nms=with_nms,
            img_meta=img_meta)

        if len(predictions) > 0:
            # Note: GLIP adopts a very strange bbox decoder logic,
            # and if 1 is not added here, it will not align with
            # the official mAP.
            predictions.bboxes[:, 2:] = predictions.bboxes[:, 2:] + 1
        return predictions
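

# Inference sketch (illustrative; ``head_cfg`` stands for a GLIP-style config
# dict for this head and is not defined in this file):
#
#   head = MODELS.build(head_cfg)
#   results = head.predict(visual_feats, language_feats, batch_data_samples)
#   # results[i].bboxes / results[i].scores / results[i].labels hold the
#   # detections for the i-th image; each data sample's token_positive_map
#   # is used to turn phrase-token logits back into class labels.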