# Copyright (c) OpenMMLab. All rights reserved.
import copy
import math
from typing import Callable, List, Optional, Sequence, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import Scale
from mmcv.ops.modulated_deform_conv import ModulatedDeformConv2d
from mmengine.config import ConfigDict
from mmengine.model import BaseModel
from mmengine.structures import InstanceData
from torch import Tensor

try:
    from transformers import BertConfig
except ImportError:
    BertConfig = None

from mmdet.registry import MODELS
from mmdet.structures.bbox import cat_boxes
from mmdet.utils import InstanceList
from ..utils import (BertEncoderLayer, VLFuse, filter_scores_and_topk,
                     permute_and_flatten, select_single_mlvl)
from ..utils.vlfuse_helper import MAX_CLAMP_VALUE
from .atss_head import ATSSHead


def convert_grounding_to_cls_scores(logits: Tensor,
                                    positive_maps: List[dict]) -> Tensor:
    """Convert grounding (token-level) logits to per-class scores.

    Each class score is the mean of the logits of the tokens that the
    positive map associates with that class. Labels in the positive
    maps are 1-based.
    """
    # The original `if positive_maps is not None` guard was dead code:
    # the map was already dereferenced before the check could fire.
    assert positive_maps is not None
    assert len(positive_maps) == logits.shape[0]  # batch size
    scores = torch.zeros(logits.shape[0], logits.shape[1],
                         len(positive_maps[0])).to(logits.device)
    if all(x == positive_maps[0] for x in positive_maps):
        # All images share the same positive map,
        # so it only needs to be computed once.
        positive_map = positive_maps[0]
        for label_j in positive_map:
            scores[:, :, label_j - 1] = logits[
                :, :, torch.LongTensor(positive_map[label_j])].mean(-1)
    else:
        for i, positive_map in enumerate(positive_maps):
            for label_j in positive_map:
                scores[i, :, label_j - 1] = logits[
                    i, :, torch.LongTensor(positive_map[label_j])].mean(-1)
    return scores
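

# A minimal illustration of the conversion above (the shapes and positive
# map below are made-up examples, not values used by this module): with a
# batch of one image, 2 priors and 4 caption tokens, the positive map
# {1: [0, 1], 2: [3]} scores class 1 as the mean logit of tokens 0-1 and
# class 2 as the logit of token 3:
#
#     logits = torch.rand(1, 2, 4)
#     scores = convert_grounding_to_cls_scores(
#         logits, positive_maps=[{1: [0, 1], 2: [3]}])
#     assert scores.shape == (1, 2, 2)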


class Conv3x3Norm(nn.Module):
    """A 3x3 convolution (optionally deformable) followed by an optional
    normalization layer.

    ``norm_type`` may be ``'bn'``, a ``('gn', num_groups)`` sequence, or
    ``None`` for no normalization.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 stride: int,
                 groups: int = 1,
                 use_dcn: bool = False,
                 norm_type: Optional[Union[Sequence, str]] = None):
        super().__init__()

        if use_dcn:
            self.conv = ModulatedDeformConv2d(
                in_channels,
                out_channels,
                kernel_size=3,
                stride=stride,
                padding=1,
                groups=groups)
        else:
            self.conv = nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=3,
                stride=stride,
                padding=1,
                groups=groups)

        if isinstance(norm_type, Sequence):
            assert len(norm_type) == 2
            assert norm_type[0] == 'gn'
            gn_group = norm_type[1]
            norm_type = norm_type[0]

        if norm_type == 'bn':
            bn_op = nn.BatchNorm2d(out_channels)
        elif norm_type == 'gn':
            bn_op = nn.GroupNorm(
                num_groups=gn_group, num_channels=out_channels)
        if norm_type is not None:
            self.bn = bn_op
        else:
            self.bn = None

    def forward(self, x, **kwargs):
        x = self.conv(x, **kwargs)
        if self.bn:
            x = self.bn(x)
        return x
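

# Usage note (channel counts here are illustrative, not fixed by this
# class): the DyHead tower built in `VLFusionModule._init_layers` below
# instantiates this block as, e.g., Conv3x3Norm(256, 256, stride=1,
# use_dcn=True, norm_type=['gn', 16]), i.e. a 3x3 (deformable) conv
# followed by GroupNorm with 16 groups.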


class DyReLU(nn.Module):
    """Dynamic ReLU: the activation coefficients are predicted from a
    globally pooled descriptor of the input feature map."""

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 expand_ratio: int = 4):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.expand_ratio = expand_ratio
        self.out_channels = out_channels

        self.fc = nn.Sequential(
            nn.Linear(in_channels, in_channels // expand_ratio),
            nn.ReLU(inplace=True),
            nn.Linear(in_channels // expand_ratio,
                      out_channels * self.expand_ratio),
            nn.Hardsigmoid(inplace=True))

    def forward(self, x) -> Tensor:
        x_out = x
        b, c, h, w = x.size()
        x = self.avg_pool(x).view(b, c)
        x = self.fc(x).view(b, -1, 1, 1)

        # Split the predicted coefficients into two slopes and two offsets,
        # then re-center them around an identity/zero initialization.
        a1, b1, a2, b2 = torch.split(x, self.out_channels, dim=1)
        a1 = (a1 - 0.5) * 2 + 1.0
        a2 = (a2 - 0.5) * 2
        b1 = b1 - 0.5
        b2 = b2 - 0.5
        out = torch.max(x_out * a1 + b1, x_out * a2 + b2)
        return out
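

# DyReLU computes the piecewise-linear activation
#
#     out = max(a1 * x + b1, a2 * x + b2)
#
# with coefficients predicted from the pooled input. Because Hardsigmoid
# outputs ~0.5 for inputs near zero, the coefficients start around
# a1 = 1, a2 = 0, b1 = 0, b2 = 0, so a freshly initialized module behaves
# approximately like a plain ReLU: max(x, 0).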


class DyConv(nn.Module):
    """Dynamic Convolution: fuses each pyramid level with its neighbouring
    levels using shared convolutions, with optional scale attention
    (DyFuse), dynamic ReLU and deformable-convolution offsets."""

    def __init__(self,
                 conv_func: Callable,
                 in_channels: int,
                 out_channels: int,
                 use_dyfuse: bool = True,
                 use_dyrelu: bool = False,
                 use_dcn: bool = False):
        super().__init__()

        self.dyconvs = nn.ModuleList()
        # dyconvs[0]: applied to the next (coarser) level, stride 1.
        # dyconvs[1]: applied to the current level, stride 1.
        # dyconvs[2]: applied to the previous (finer) level, stride 2.
        self.dyconvs.append(conv_func(in_channels, out_channels, 1))
        self.dyconvs.append(conv_func(in_channels, out_channels, 1))
        self.dyconvs.append(conv_func(in_channels, out_channels, 2))

        if use_dyfuse:
            self.attnconv = nn.Sequential(
                nn.AdaptiveAvgPool2d(1),
                nn.Conv2d(in_channels, 1, kernel_size=1),
                nn.ReLU(inplace=True))
            self.h_sigmoid = nn.Hardsigmoid(inplace=True)
        else:
            self.attnconv = None

        if use_dyrelu:
            self.relu = DyReLU(in_channels, out_channels)
        else:
            self.relu = nn.ReLU()

        if use_dcn:
            # 27 = 18 offset channels + 9 mask channels for a 3x3
            # modulated deformable convolution.
            self.offset = nn.Conv2d(
                in_channels, 27, kernel_size=3, stride=1, padding=1)
        else:
            self.offset = None

        self.init_weights()

    def init_weights(self):
        for m in self.dyconvs.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight.data, 0, 0.01)
                if m.bias is not None:
                    m.bias.data.zero_()
        if self.attnconv is not None:
            for m in self.attnconv.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.normal_(m.weight.data, 0, 0.01)
                    if m.bias is not None:
                        m.bias.data.zero_()
    def forward(self, inputs: dict) -> dict:
        visual_feats = inputs['visual']

        out_vis_feats = []
        for level, feature in enumerate(visual_feats):
            offset_conv_args = {}
            if self.offset is not None:
                offset_mask = self.offset(feature)
                offset = offset_mask[:, :18, :, :]
                mask = offset_mask[:, 18:, :, :].sigmoid()
                offset_conv_args = dict(offset=offset, mask=mask)

            temp_feats = [self.dyconvs[1](feature, **offset_conv_args)]

            if level > 0:
                # Downsample the finer level below with the stride-2 conv.
                temp_feats.append(self.dyconvs[2](visual_feats[level - 1],
                                                  **offset_conv_args))
            if level < len(visual_feats) - 1:
                # Upsample the coarser level above to the current size.
                # `F.interpolate` with align_corners=True replaces the
                # deprecated `F.upsample_bilinear` with identical behavior.
                temp_feats.append(
                    F.interpolate(
                        self.dyconvs[0](visual_feats[level + 1],
                                        **offset_conv_args),
                        size=[feature.size(2), feature.size(3)],
                        mode='bilinear',
                        align_corners=True))
            mean_feats = torch.mean(
                torch.stack(temp_feats), dim=0, keepdim=False)

            if self.attnconv is not None:
                attn_feat = []
                res_feat = []
                for feat in temp_feats:
                    res_feat.append(feat)
                    attn_feat.append(self.attnconv(feat))

                res_feat = torch.stack(res_feat)
                spa_pyr_attn = self.h_sigmoid(torch.stack(attn_feat))

                mean_feats = torch.mean(
                    res_feat * spa_pyr_attn, dim=0, keepdim=False)

            out_vis_feats.append(mean_feats)

        out_vis_feats = [self.relu(item) for item in out_vis_feats]

        features_dict = {'visual': out_vis_feats, 'lang': inputs['lang']}

        return features_dict
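

# The loop above performs DyHead-style scale fusion: each pyramid level is
# combined with the conv responses of its finer and coarser neighbours,
# either as a plain mean or, when `use_dyfuse` is enabled, as a mean
# weighted by the hard-sigmoid scale attention `spa_pyr_attn`.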


class VLFusionModule(BaseModel):
    """Visual-language fusion module."""

    def __init__(self,
                 in_channels: int,
                 feat_channels: int,
                 num_base_priors: int,
                 early_fuse: bool = False,
                 num_dyhead_blocks: int = 6,
                 lang_model_name: str = 'bert-base-uncased',
                 use_dyrelu: bool = True,
                 use_dyfuse: bool = True,
                 use_dcn: bool = True,
                 use_checkpoint: bool = False,
                 **kwargs) -> None:
        super().__init__(**kwargs)
        if BertConfig is None:
            raise RuntimeError(
                'transformers is not installed, please install it by: '
                'pip install transformers.')
        self.in_channels = in_channels
        self.feat_channels = feat_channels
        self.num_base_priors = num_base_priors
        self.early_fuse = early_fuse
        self.num_dyhead_blocks = num_dyhead_blocks
        self.use_dyrelu = use_dyrelu
        self.use_dyfuse = use_dyfuse
        self.use_dcn = use_dcn
        self.use_checkpoint = use_checkpoint

        self.lang_cfg = BertConfig.from_pretrained(lang_model_name)
        self.lang_dim = self.lang_cfg.hidden_size
        self._init_layers()
    def _init_layers(self) -> None:
        """Initialize layers of the model."""
        bias_value = -math.log((1 - 0.01) / 0.01)
        dyhead_tower = []
        for i in range(self.num_dyhead_blocks):
            if self.early_fuse:
                # cross-modality fusion
                dyhead_tower.append(VLFuse(use_checkpoint=self.use_checkpoint))
                # lang branch
                dyhead_tower.append(
                    BertEncoderLayer(
                        self.lang_cfg,
                        clamp_min_for_underflow=True,
                        clamp_max_for_overflow=True))

            # vision branch
            dyhead_tower.append(
                DyConv(
                    lambda i, o, s: Conv3x3Norm(
                        i, o, s, use_dcn=self.use_dcn, norm_type=['gn', 16]),
                    self.in_channels if i == 0 else self.feat_channels,
                    self.feat_channels,
                    use_dyrelu=(self.use_dyrelu
                                and self.in_channels == self.feat_channels)
                    if i == 0 else self.use_dyrelu,
                    use_dyfuse=(self.use_dyfuse
                                and self.in_channels == self.feat_channels)
                    if i == 0 else self.use_dyfuse,
                    use_dcn=(self.use_dcn
                             and self.in_channels == self.feat_channels)
                    if i == 0 else self.use_dcn,
                ))

        self.add_module('dyhead_tower', nn.Sequential(*dyhead_tower))

        self.bbox_pred = nn.Conv2d(
            self.feat_channels, self.num_base_priors * 4, kernel_size=1)
        self.centerness = nn.Conv2d(
            self.feat_channels, self.num_base_priors * 1, kernel_size=1)
        self.dot_product_projection_text = nn.Linear(
            self.lang_dim,
            self.num_base_priors * self.feat_channels,
            bias=True)
        self.log_scale = nn.Parameter(torch.Tensor([0.0]), requires_grad=True)
        self.bias_lang = nn.Parameter(
            torch.zeros(self.lang_dim), requires_grad=True)
        self.bias0 = nn.Parameter(
            torch.Tensor([bias_value]), requires_grad=True)
        self.scales = nn.ModuleList([Scale(1.0) for _ in range(5)])
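    # `bias_value` above follows the focal-loss prior initialization:
    # with prior probability p = 0.01, bias = -log((1 - p) / p) ~= -4.595,
    # so sigmoid(logit + bias0) starts near 0.01 and positive predictions
    # are rare at the beginning of training.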

    def forward(self, visual_feats: Tuple[Tensor],
                language_feats: dict) -> Tuple:
        feat_inputs = {'visual': visual_feats, 'lang': language_feats}
        dyhead_tower = self.dyhead_tower(feat_inputs)

        if self.early_fuse:
            embedding = dyhead_tower['lang']['hidden']
        else:
            embedding = language_feats['embedded']

        embedding = F.normalize(embedding, p=2, dim=-1)
        dot_product_proj_tokens = self.dot_product_projection_text(
            embedding / 2.0)
        dot_product_proj_tokens_bias = torch.matmul(
            embedding, self.bias_lang) + self.bias0

        bbox_preds = []
        centerness = []
        cls_logits = []

        for i, feature in enumerate(visual_feats):
            visual = dyhead_tower['visual'][i]
            B, C, H, W = visual.shape
            bbox_pred = self.scales[i](self.bbox_pred(visual))
            bbox_preds.append(bbox_pred)
            centerness.append(self.centerness(visual))

            dot_product_proj_queries = permute_and_flatten(
                visual, B, self.num_base_priors, C, H, W)

            bias = dot_product_proj_tokens_bias.unsqueeze(1).repeat(
                1, self.num_base_priors, 1)
            dot_product_logit = (
                torch.matmul(dot_product_proj_queries,
                             dot_product_proj_tokens.transpose(-1, -2)) /
                self.log_scale.exp()) + bias
            # Clamp in both directions to keep mixed-precision training
            # numerically stable.
            dot_product_logit = torch.clamp(
                dot_product_logit, max=MAX_CLAMP_VALUE)
            dot_product_logit = torch.clamp(
                dot_product_logit, min=-MAX_CLAMP_VALUE)
            cls_logits.append(dot_product_logit)

        return bbox_preds, centerness, cls_logits
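

# For each prior a and token l, the classification logit computed above is
#
#     logit[a, l] = <query_a, token_l> / exp(log_scale)
#                   + <embedding_l, bias_lang> + bias0
#
# where query_a is the flattened visual feature and token_l the projected
# (L2-normalized, halved) text embedding, clamped to
# [-MAX_CLAMP_VALUE, MAX_CLAMP_VALUE].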


@MODELS.register_module()
class ATSSVLFusionHead(ATSSHead):
    """ATSS head with visual-language fusion module.

    Args:
        early_fuse (bool): Whether to fuse visual and language features.
            Defaults to False.
        use_checkpoint (bool): Whether to use gradient checkpointing.
            Defaults to False.
        num_dyhead_blocks (int): Number of dynamic head blocks. Defaults to 6.
        lang_model_name (str): Name of the language model.
            Defaults to 'bert-base-uncased'.
    """

    def __init__(self,
                 *args,
                 early_fuse: bool = False,
                 use_checkpoint: bool = False,
                 num_dyhead_blocks: int = 6,
                 lang_model_name: str = 'bert-base-uncased',
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.head = VLFusionModule(
            in_channels=self.in_channels,
            feat_channels=self.feat_channels,
            num_base_priors=self.num_base_priors,
            early_fuse=early_fuse,
            use_checkpoint=use_checkpoint,
            num_dyhead_blocks=num_dyhead_blocks,
            lang_model_name=lang_model_name)

    def _init_layers(self) -> None:
        """No need to initialize the ATSS head layer."""
        pass

    def forward(self, visual_feats: Tuple[Tensor],
                language_feats: dict) -> Tuple[Tensor]:
        """Forward function."""
        bbox_preds, centerness, cls_logits = self.head(visual_feats,
                                                       language_feats)
        return bbox_preds, centerness, cls_logits

    def predict(self,
                visual_feats: Tuple[Tensor],
                language_feats: dict,
                batch_data_samples,
                rescale: bool = True):
        """Perform forward propagation of the detection head and predict
        detection results on the features of the upstream network.

        Args:
            visual_feats (tuple[Tensor]): Multi-level visual features from the
                upstream network, each is a 4D-tensor.
            language_feats (dict): Language features from the upstream
                network.
            batch_data_samples (List[:obj:`DetDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
            rescale (bool, optional): Whether to rescale the results.
                Defaults to True.

        Returns:
            list[obj:`InstanceData`]: Detection results of each image
            after the post process.
        """
        batch_img_metas = [
            data_samples.metainfo for data_samples in batch_data_samples
        ]
        batch_token_positive_maps = [
            data_samples.token_positive_map
            for data_samples in batch_data_samples
        ]
        outs = self(visual_feats, language_feats)

        predictions = self.predict_by_feat(
            *outs,
            batch_img_metas=batch_img_metas,
            batch_token_positive_maps=batch_token_positive_maps,
            rescale=rescale)
        return predictions
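    # Expected input sketch (assumed, following mmdet conventions): each
    # `DetDataSample` in `batch_data_samples` must carry a
    # `token_positive_map` attribute, e.g. {1: [0, 1]}, mapping 1-based
    # class ids to the caption token indices that describe that class.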

    def predict_by_feat(self,
                        bbox_preds: List[Tensor],
                        score_factors: List[Tensor],
                        cls_logits: List[Tensor],
                        batch_img_metas: Optional[List[dict]] = None,
                        batch_token_positive_maps: Optional[List[dict]] = None,
                        cfg: Optional[ConfigDict] = None,
                        rescale: bool = False,
                        with_nms: bool = True) -> InstanceList:
        """Transform a batch of output features extracted from the head into
        bbox results.

        Note: When score_factors is not None, the cls_scores are usually
        multiplied by it to obtain the real scores used in NMS, e.g. the
        centerness in FCOS or the IoU branch in ATSS.

        Args:
            bbox_preds (list[Tensor]): Box energies / deltas for all
                scale levels, each is a 4D-tensor, has shape
                (batch_size, num_priors * 4, H, W).
            score_factors (list[Tensor], optional): Score factors for
                all scale levels, each is a 4D-tensor, has shape
                (batch_size, num_priors * 1, H, W). Defaults to None.
            cls_logits (list[Tensor]): Token-level classification logits
                for all scale levels, each is a 3D-tensor, has shape
                (batch_size, num_priors * H * W, num_tokens).
            batch_img_metas (list[dict], Optional): Batch image meta info.
                Defaults to None.
            batch_token_positive_maps (list[dict], Optional): Batch token
                positive map. Defaults to None.
            cfg (ConfigDict, optional): Test / postprocessing
                configuration. If None, `test_cfg` is used.
                Defaults to None.
            rescale (bool): If True, return boxes in original image space.
                Defaults to False.
            with_nms (bool): If True, do nms before return boxes.
                Defaults to True.

        Returns:
            list[:obj:`InstanceData`]: Object detection results of each image
            after the post process. Each item usually contains following keys.

            - scores (Tensor): Classification scores, has a shape
              (num_instance, )
            - labels (Tensor): Labels of bboxes, has a shape
              (num_instances, ).
            - bboxes (Tensor): Has a shape (num_instances, 4),
              the last dimension 4 arranged as (x1, y1, x2, y2).
        """
        assert len(bbox_preds) == len(score_factors)
        num_levels = len(bbox_preds)

        featmap_sizes = [bbox_preds[i].shape[-2:] for i in range(num_levels)]
        mlvl_priors = self.prior_generator.grid_priors(
            featmap_sizes,
            dtype=bbox_preds[0].dtype,
            device=bbox_preds[0].device)

        result_list = []

        for img_id in range(len(batch_img_metas)):
            img_meta = batch_img_metas[img_id]
            token_positive_maps = batch_token_positive_maps[img_id]
            bbox_pred_list = select_single_mlvl(
                bbox_preds, img_id, detach=True)
            score_factor_list = select_single_mlvl(
                score_factors, img_id, detach=True)
            cls_logit_list = select_single_mlvl(
                cls_logits, img_id, detach=True)

            results = self._predict_by_feat_single(
                bbox_pred_list=bbox_pred_list,
                score_factor_list=score_factor_list,
                cls_logit_list=cls_logit_list,
                mlvl_priors=mlvl_priors,
                token_positive_maps=token_positive_maps,
                img_meta=img_meta,
                cfg=cfg,
                rescale=rescale,
                with_nms=with_nms)
            result_list.append(results)
        return result_list

    def _predict_by_feat_single(self,
                                bbox_pred_list: List[Tensor],
                                score_factor_list: List[Tensor],
                                cls_logit_list: List[Tensor],
                                mlvl_priors: List[Tensor],
                                token_positive_maps: dict,
                                img_meta: dict,
                                cfg: ConfigDict,
                                rescale: bool = True,
                                with_nms: bool = True) -> InstanceData:
        """Transform a single image's features extracted from the head into
        bbox results.

        Args:
            bbox_pred_list (list[Tensor]): Box energies / deltas from
                all scale levels of a single image, each item has shape
                (num_priors * 4, H, W).
            score_factor_list (list[Tensor]): Score factor from all scale
                levels of a single image, each item has shape
                (num_priors * 1, H, W).
            cls_logit_list (list[Tensor]): Token-level logits from all
                scale levels of a single image, each item has shape
                (num_priors * H * W, num_tokens).
            mlvl_priors (list[Tensor]): Each element in the list is
                the priors of a single level in feature pyramid. In all
                anchor-based methods, it has shape (num_priors, 4). In
                all anchor-free methods, it has shape (num_priors, 2)
                when `with_stride=True`, otherwise it still has shape
                (num_priors, 4).
            token_positive_maps (dict): Token positive map.
            img_meta (dict): Image meta info.
            cfg (mmengine.Config): Test / postprocessing configuration.
                If None, `test_cfg` is used.
            rescale (bool): If True, return boxes in original image space.
                Defaults to True.
            with_nms (bool): If True, do nms before return boxes.
                Defaults to True.

        Returns:
            :obj:`InstanceData`: Detection results of each image
            after the post process.
            Each item usually contains following keys.

            - scores (Tensor): Classification scores, has a shape
              (num_instance, )
            - labels (Tensor): Labels of bboxes, has a shape
              (num_instances, ).
            - bboxes (Tensor): Has a shape (num_instances, 4),
              the last dimension 4 arranged as (x1, y1, x2, y2).
        """
        cfg = self.test_cfg if cfg is None else cfg
        cfg = copy.deepcopy(cfg)
        img_shape = img_meta['img_shape']
        nms_pre = cfg.get('nms_pre', -1)
        score_thr = cfg.get('score_thr', 0)

        mlvl_bbox_preds = []
        mlvl_valid_priors = []
        mlvl_scores = []
        mlvl_labels = []

        for level_idx, (bbox_pred, score_factor, cls_logit, priors) in \
                enumerate(zip(bbox_pred_list,
                              score_factor_list, cls_logit_list,
                              mlvl_priors)):
            bbox_pred = bbox_pred.permute(1, 2, 0).reshape(
                -1, self.bbox_coder.encode_size)
            score_factor = score_factor.permute(1, 2, 0).reshape(-1).sigmoid()

            scores = convert_grounding_to_cls_scores(
                logits=cls_logit.sigmoid()[None],
                positive_maps=[token_positive_maps])[0]

            results = filter_scores_and_topk(
                scores, score_thr, nms_pre,
                dict(bbox_pred=bbox_pred, priors=priors))
            scores, labels, keep_idxs, filtered_results = results

            bbox_pred = filtered_results['bbox_pred']
            priors = filtered_results['priors']
            score_factor = score_factor[keep_idxs]
            # Combine classification scores with the centerness factor.
            scores = torch.sqrt(scores * score_factor)

            mlvl_bbox_preds.append(bbox_pred)
            mlvl_valid_priors.append(priors)
            mlvl_scores.append(scores)
            mlvl_labels.append(labels)

        bbox_pred = torch.cat(mlvl_bbox_preds)
        priors = cat_boxes(mlvl_valid_priors)
        bboxes = self.bbox_coder.decode(priors, bbox_pred, max_shape=img_shape)

        results = InstanceData()
        results.bboxes = bboxes
        results.scores = torch.cat(mlvl_scores)
        results.labels = torch.cat(mlvl_labels)

        predictions = self._bbox_post_process(
            results=results,
            cfg=cfg,
            rescale=rescale,
            with_nms=with_nms,
            img_meta=img_meta)

        if len(predictions) > 0:
            # Note: GLIP adopts an unusual bbox decoding scheme; without
            # adding 1 to the bottom-right coordinates here, the results
            # do not align with the official mAP.
            predictions.bboxes[:, 2:] = predictions.bboxes[:, 2:] + 1
        return predictions
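

# A hedged configuration sketch (the values below are illustrative
# assumptions, not defaults defined in this file) of how the head might be
# declared in an mmdet config:
#
#     bbox_head = dict(
#         type='ATSSVLFusionHead',
#         num_classes=80,
#         in_channels=256,
#         feat_channels=256,
#         early_fuse=True,
#         num_dyhead_blocks=6,
#         lang_model_name='bert-base-uncased',
#         anchor_generator=dict(
#             type='AnchorGenerator',
#             ratios=[1.0],
#             octave_base_scale=8,
#             scales_per_octave=1,
#             strides=[8, 16, 32, 64, 128]))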