# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import Conv2d, ConvModule
from mmengine.model import BaseModule, ModuleList, caffe2_xavier_init
from torch import Tensor

from mmdet.registry import MODELS
from mmdet.utils import ConfigType, OptMultiConfig
from .positional_encoding import SinePositionalEncoding
from .transformer import DetrTransformerEncoder


@MODELS.register_module()
class PixelDecoder(BaseModule):
    """Pixel decoder with an FPN-like structure.

    Args:
        in_channels (list[int] | tuple[int]): Number of channels in the
            input feature maps.
        feat_channels (int): Number of channels for the intermediate
            feature maps.
        out_channels (int): Number of channels for the output mask feature.
        norm_cfg (:obj:`ConfigDict` or dict): Config for normalization.
            Defaults to dict(type='GN', num_groups=32).
        act_cfg (:obj:`ConfigDict` or dict): Config for activation.
            Defaults to dict(type='ReLU').
        init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \
            dict], optional): Initialization config dict. Defaults to None.
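
    Example:
        A minimal usage sketch; the channel counts and spatial sizes below
        are illustrative assumptions, not values prescribed by this class.

        >>> import torch
        >>> decoder = PixelDecoder(
        ...     in_channels=[256, 512, 1024, 2048],
        ...     feat_channels=256,
        ...     out_channels=256)
        >>> sizes = [64, 32, 16, 8]  # e.g. strides 4/8/16/32 of a 256px input
        >>> feats = [
        ...     torch.rand(1, c, s, s)
        ...     for c, s in zip([256, 512, 1024, 2048], sizes)]
        >>> mask_feature, memory = decoder(feats, batch_img_metas=[])
        >>> mask_feature.shape
        torch.Size([1, 256, 64, 64])
        >>> memory.shape
        torch.Size([1, 2048, 8, 8])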
- """

    def __init__(self,
                 in_channels: Union[List[int], Tuple[int]],
                 feat_channels: int,
                 out_channels: int,
                 norm_cfg: ConfigType = dict(type='GN', num_groups=32),
                 act_cfg: ConfigType = dict(type='ReLU'),
                 init_cfg: OptMultiConfig = None) -> None:
        super().__init__(init_cfg=init_cfg)
        self.in_channels = in_channels
        self.num_inputs = len(in_channels)
        self.lateral_convs = ModuleList()
        self.output_convs = ModuleList()
        self.use_bias = norm_cfg is None
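        # For every level except the coarsest, build a 1x1 lateral conv that
        # projects the backbone feature to ``feat_channels`` and a 3x3 output
        # conv that smooths the fused top-down feature, as in a standard FPN.
        # Convs carry a bias only when no normalization layer is configured.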
        for i in range(0, self.num_inputs - 1):
            lateral_conv = ConvModule(
                in_channels[i],
                feat_channels,
                kernel_size=1,
                bias=self.use_bias,
                norm_cfg=norm_cfg,
                act_cfg=None)
            output_conv = ConvModule(
                feat_channels,
                feat_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=self.use_bias,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)
            self.lateral_convs.append(lateral_conv)
            self.output_convs.append(output_conv)

        self.last_feat_conv = ConvModule(
            in_channels[-1],
            feat_channels,
            kernel_size=3,
            padding=1,
            stride=1,
            bias=self.use_bias,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        self.mask_feature = Conv2d(
            feat_channels, out_channels, kernel_size=3, stride=1, padding=1)

    def init_weights(self) -> None:
        """Initialize weights."""
        for i in range(0, self.num_inputs - 1):
            caffe2_xavier_init(self.lateral_convs[i].conv, bias=0)
            caffe2_xavier_init(self.output_convs[i].conv, bias=0)

        caffe2_xavier_init(self.mask_feature, bias=0)
        caffe2_xavier_init(self.last_feat_conv.conv, bias=0)

    def forward(self, feats: List[Tensor],
                batch_img_metas: List[dict]) -> Tuple[Tensor, Tensor]:
        """
        Args:
            feats (list[Tensor]): Feature maps of each level. Each has
                shape (batch_size, c, h, w).
            batch_img_metas (list[dict]): List of image information.
                Passed in for creating a more accurate padding mask. Not
                used here.

        Returns:
            tuple[Tensor, Tensor]: A tuple containing the following:

                - mask_feature (Tensor): Shape (batch_size, c, h, w).
                - memory (Tensor): Output of the last stage of the
                  backbone. Shape (batch_size, c, h, w).
        """
        y = self.last_feat_conv(feats[-1])
        for i in range(self.num_inputs - 2, -1, -1):
            x = feats[i]
            cur_feat = self.lateral_convs[i](x)
            y = cur_feat + \
                F.interpolate(y, size=cur_feat.shape[-2:], mode='nearest')
            y = self.output_convs[i](y)

        mask_feature = self.mask_feature(y)
        memory = feats[-1]
        return mask_feature, memory


@MODELS.register_module()
class TransformerEncoderPixelDecoder(PixelDecoder):
    """Pixel decoder with a transformer encoder inside.

    Args:
        in_channels (list[int] | tuple[int]): Number of channels in the
            input feature maps.
        feat_channels (int): Number of channels for the intermediate
            feature maps.
        out_channels (int): Number of channels for the output mask feature.
        norm_cfg (:obj:`ConfigDict` or dict): Config for normalization.
            Defaults to dict(type='GN', num_groups=32).
        act_cfg (:obj:`ConfigDict` or dict): Config for activation.
            Defaults to dict(type='ReLU').
        encoder (:obj:`ConfigDict` or dict): Config for the transformer
            encoder. Defaults to None.
        positional_encoding (:obj:`ConfigDict` or dict): Config for the
            transformer encoder position encoding. Defaults to
            dict(num_feats=128, normalize=True).
        init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \
            dict], optional): Initialization config dict. Defaults to None.
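
    Example:
        A minimal usage sketch. The encoder config mirrors MaskFormer-style
        settings for mmdet's ``DetrTransformerEncoder``; all channel counts,
        sizes and meta values are illustrative assumptions.

        >>> import torch
        >>> decoder = TransformerEncoderPixelDecoder(
        ...     in_channels=[256, 512, 1024, 2048],
        ...     feat_channels=256,
        ...     out_channels=256,
        ...     encoder=dict(
        ...         num_layers=6,
        ...         layer_cfg=dict(
        ...             self_attn_cfg=dict(
        ...                 embed_dims=256, num_heads=8, batch_first=True),
        ...             ffn_cfg=dict(
        ...                 embed_dims=256, feedforward_channels=2048))))
        >>> feats = [
        ...     torch.rand(1, c, s, s)
        ...     for c, s in zip([256, 512, 1024, 2048], [64, 32, 16, 8])]
        >>> metas = [dict(batch_input_shape=(256, 256),
        ...               img_shape=(256, 256))]
        >>> mask_feature, memory = decoder(feats, metas)
        >>> mask_feature.shape
        torch.Size([1, 256, 64, 64])
        >>> memory.shape
        torch.Size([1, 256, 8, 8])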
- """

    def __init__(self,
                 in_channels: Union[List[int], Tuple[int]],
                 feat_channels: int,
                 out_channels: int,
                 norm_cfg: ConfigType = dict(type='GN', num_groups=32),
                 act_cfg: ConfigType = dict(type='ReLU'),
                 encoder: ConfigType = None,
                 positional_encoding: ConfigType = dict(
                     num_feats=128, normalize=True),
                 init_cfg: OptMultiConfig = None) -> None:
        super().__init__(
            in_channels=in_channels,
            feat_channels=feat_channels,
            out_channels=out_channels,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            init_cfg=init_cfg)
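        # The coarsest feature map is processed by the transformer encoder
        # instead of ``last_feat_conv``, so the conv built by the parent
        # class is dropped here.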
        self.last_feat_conv = None

        self.encoder = DetrTransformerEncoder(**encoder)
        self.encoder_embed_dims = self.encoder.embed_dims
        assert self.encoder_embed_dims == feat_channels, \
            'embed_dims({}) of transformer encoder must be equal to ' \
            'feat_channels({})'.format(self.encoder_embed_dims, feat_channels)
        self.positional_encoding = SinePositionalEncoding(
            **positional_encoding)
        self.encoder_in_proj = Conv2d(
            in_channels[-1], feat_channels, kernel_size=1)
        self.encoder_out_proj = ConvModule(
            feat_channels,
            feat_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=self.use_bias,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)

    def init_weights(self) -> None:
        """Initialize weights."""
        for i in range(0, self.num_inputs - 1):
            caffe2_xavier_init(self.lateral_convs[i].conv, bias=0)
            caffe2_xavier_init(self.output_convs[i].conv, bias=0)

        caffe2_xavier_init(self.mask_feature, bias=0)
        caffe2_xavier_init(self.encoder_in_proj, bias=0)
        caffe2_xavier_init(self.encoder_out_proj.conv, bias=0)
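        # DETR-style initialization for the transformer encoder:
        # Xavier-uniform on every parameter with more than one dimension
        # (i.e. weight matrices, not biases or norm parameters).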
        for p in self.encoder.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, feats: List[Tensor],
                batch_img_metas: List[dict]) -> Tuple[Tensor, Tensor]:
        """
        Args:
            feats (list[Tensor]): Feature maps of each level. Each has
                shape (batch_size, c, h, w).
            batch_img_metas (list[dict]): List of image information. Passed
                in for creating a more accurate padding mask.

        Returns:
            tuple[Tensor, Tensor]: A tuple containing the following:

                - mask_feature (Tensor): Shape (batch_size, c, h, w).
                - memory (Tensor): Shape (batch_size, c, h, w).
        """
        feat_last = feats[-1]
        bs, c, h, w = feat_last.shape
        input_img_h, input_img_w = batch_img_metas[0]['batch_input_shape']
        padding_mask = feat_last.new_ones((bs, input_img_h, input_img_w),
                                          dtype=torch.float32)
        for i in range(bs):
            img_h, img_w = batch_img_metas[i]['img_shape']
            padding_mask[i, :img_h, :img_w] = 0
        padding_mask = F.interpolate(
            padding_mask.unsqueeze(1),
            size=feat_last.shape[-2:],
            mode='nearest').to(torch.bool).squeeze(1)

        pos_embed = self.positional_encoding(padding_mask)
        feat_last = self.encoder_in_proj(feat_last)
        # (batch_size, c, h, w) -> (batch_size, num_queries, c)
        feat_last = feat_last.flatten(2).permute(0, 2, 1)
        pos_embed = pos_embed.flatten(2).permute(0, 2, 1)
        # (batch_size, h, w) -> (batch_size, h*w)
        padding_mask = padding_mask.flatten(1)
        memory = self.encoder(
            query=feat_last,
            query_pos=pos_embed,
            key_padding_mask=padding_mask)
        # (batch_size, num_queries, c) -> (batch_size, c, h, w)
        memory = memory.permute(0, 2, 1).view(bs, self.encoder_embed_dims, h,
                                              w)
        y = self.encoder_out_proj(memory)
        for i in range(self.num_inputs - 2, -1, -1):
            x = feats[i]
            cur_feat = self.lateral_convs[i](x)
            y = cur_feat + \
                F.interpolate(y, size=cur_feat.shape[-2:], mode='nearest')
            y = self.output_convs[i](y)

        mask_feature = self.mask_feature(y)
        return mask_feature, memory
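
# A registry-based construction sketch (illustrative, not part of this
# module): both classes above are registered in ``MODELS``, so configs can
# build them by ``type`` name. The channel counts and encoder settings below
# are assumptions adapted from MaskFormer-style configs.
#
#     from mmdet.registry import MODELS
#     pixel_decoder = MODELS.build(
#         dict(
#             type='TransformerEncoderPixelDecoder',
#             in_channels=[256, 512, 1024, 2048],
#             feat_channels=256,
#             out_channels=256,
#             encoder=dict(
#                 num_layers=6,
#                 layer_cfg=dict(
#                     self_attn_cfg=dict(
#                         embed_dims=256, num_heads=8, batch_first=True),
#                     ffn_cfg=dict(
#                         embed_dims=256, feedforward_channels=2048)))))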