# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import Conv2d, ConvModule
from mmengine.model import BaseModule, ModuleList, caffe2_xavier_init
from torch import Tensor

from mmdet.registry import MODELS
from mmdet.utils import ConfigType, OptMultiConfig
from .positional_encoding import SinePositionalEncoding
from .transformer import DetrTransformerEncoder


@MODELS.register_module()
class PixelDecoder(BaseModule):
    """Pixel decoder with an FPN-like structure.

    Args:
        in_channels (list[int] | tuple[int]): Number of channels in the
            input feature maps.
        feat_channels (int): Number of channels for features.
        out_channels (int): Number of channels for output.
        norm_cfg (:obj:`ConfigDict` or dict): Config for normalization.
            Defaults to dict(type='GN', num_groups=32).
        act_cfg (:obj:`ConfigDict` or dict): Config for activation.
            Defaults to dict(type='ReLU').
        init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \
            dict], optional): Initialization config dict. Defaults to None.
    """

    def __init__(self,
                 in_channels: Union[List[int], Tuple[int]],
                 feat_channels: int,
                 out_channels: int,
                 norm_cfg: ConfigType = dict(type='GN', num_groups=32),
                 act_cfg: ConfigType = dict(type='ReLU'),
                 init_cfg: OptMultiConfig = None) -> None:
        super().__init__(init_cfg=init_cfg)
        self.in_channels = in_channels
        self.num_inputs = len(in_channels)
        self.lateral_convs = ModuleList()
        self.output_convs = ModuleList()
        self.use_bias = norm_cfg is None
        for i in range(0, self.num_inputs - 1):
            lateral_conv = ConvModule(
                in_channels[i],
                feat_channels,
                kernel_size=1,
                bias=self.use_bias,
                norm_cfg=norm_cfg,
                act_cfg=None)
            output_conv = ConvModule(
                feat_channels,
                feat_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=self.use_bias,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)
            self.lateral_convs.append(lateral_conv)
            self.output_convs.append(output_conv)

        self.last_feat_conv = ConvModule(
            in_channels[-1],
            feat_channels,
            kernel_size=3,
            padding=1,
            stride=1,
            bias=self.use_bias,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        self.mask_feature = Conv2d(
            feat_channels, out_channels, kernel_size=3, stride=1, padding=1)
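
    # Layout note (added; not in the original source): for
    # in_channels=[c0, ..., cN], the loop above builds a lateral 1x1 conv
    # and an output 3x3 conv for every level except the last; the coarsest
    # level cN is projected by last_feat_conv, and mask_feature maps the
    # finest fused feature to out_channels.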

    def init_weights(self) -> None:
        """Initialize weights."""
        for i in range(0, self.num_inputs - 2):
            caffe2_xavier_init(self.lateral_convs[i].conv, bias=0)
            caffe2_xavier_init(self.output_convs[i].conv, bias=0)

        caffe2_xavier_init(self.mask_feature, bias=0)
        caffe2_xavier_init(self.last_feat_conv.conv, bias=0)

    def forward(self, feats: List[Tensor],
                batch_img_metas: List[dict]) -> Tuple[Tensor, Tensor]:
        """
        Args:
            feats (list[Tensor]): Feature maps of each level. Each has
                shape of (batch_size, c, h, w).
            batch_img_metas (list[dict]): List of image information. Passed
                in for creating a more accurate padding mask. Not used here.

        Returns:
            tuple[Tensor, Tensor]: A tuple containing the following:

                - mask_feature (Tensor): Shape (batch_size, c, h, w).
                - memory (Tensor): Output of the last stage of the backbone.
                  Shape (batch_size, c, h, w).
        """
        y = self.last_feat_conv(feats[-1])
        # Top-down path: upsample the coarser map and fuse it with the
        # lateral feature of each finer level.
        for i in range(self.num_inputs - 2, -1, -1):
            x = feats[i]
            cur_feat = self.lateral_convs[i](x)
            y = cur_feat + \
                F.interpolate(y, size=cur_feat.shape[-2:], mode='nearest')
            y = self.output_convs[i](y)

        mask_feature = self.mask_feature(y)
        memory = feats[-1]
        return mask_feature, memory
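

# Usage sketch (illustrative, not part of the original file; shapes assume a
# ResNet-50-style backbone with strides 4/8/16/32 on a 256x256 input):
#
#   decoder = PixelDecoder(
#       in_channels=[256, 512, 1024, 2048],
#       feat_channels=256,
#       out_channels=256)
#   decoder.init_weights()
#   feats = [torch.rand(2, c, s, s)
#            for c, s in [(256, 64), (512, 32), (1024, 16), (2048, 8)]]
#   # batch_img_metas is unused by this class, so an empty list suffices.
#   mask_feature, memory = decoder(feats, batch_img_metas=[])
#   # mask_feature: (2, 256, 64, 64); memory is feats[-1]: (2, 2048, 8, 8)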


@MODELS.register_module()
class TransformerEncoderPixelDecoder(PixelDecoder):
    """Pixel decoder with a transformer encoder inside.

    Args:
        in_channels (list[int] | tuple[int]): Number of channels in the
            input feature maps.
        feat_channels (int): Number of channels for features.
        out_channels (int): Number of channels for output.
        norm_cfg (:obj:`ConfigDict` or dict): Config for normalization.
            Defaults to dict(type='GN', num_groups=32).
        act_cfg (:obj:`ConfigDict` or dict): Config for activation.
            Defaults to dict(type='ReLU').
        encoder (:obj:`ConfigDict` or dict): Config for the transformer
            encoder. Defaults to None.
        positional_encoding (:obj:`ConfigDict` or dict): Config for the
            transformer encoder positional encoding. Defaults to
            dict(num_feats=128, normalize=True).
        init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \
            dict], optional): Initialization config dict. Defaults to None.
    """

    def __init__(self,
                 in_channels: Union[List[int], Tuple[int]],
                 feat_channels: int,
                 out_channels: int,
                 norm_cfg: ConfigType = dict(type='GN', num_groups=32),
                 act_cfg: ConfigType = dict(type='ReLU'),
                 encoder: ConfigType = None,
                 positional_encoding: ConfigType = dict(
                     num_feats=128, normalize=True),
                 init_cfg: OptMultiConfig = None) -> None:
        super().__init__(
            in_channels=in_channels,
            feat_channels=feat_channels,
            out_channels=out_channels,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            init_cfg=init_cfg)
        self.last_feat_conv = None

        self.encoder = DetrTransformerEncoder(**encoder)
        self.encoder_embed_dims = self.encoder.embed_dims
        assert self.encoder_embed_dims == feat_channels, \
            'embed_dims({}) of transformer encoder must equal ' \
            'feat_channels({})'.format(self.encoder_embed_dims,
                                       feat_channels)
        self.positional_encoding = SinePositionalEncoding(
            **positional_encoding)
        self.encoder_in_proj = Conv2d(
            in_channels[-1], feat_channels, kernel_size=1)
        self.encoder_out_proj = ConvModule(
            feat_channels,
            feat_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=self.use_bias,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)

    def init_weights(self) -> None:
        """Initialize weights."""
        for i in range(0, self.num_inputs - 2):
            caffe2_xavier_init(self.lateral_convs[i].conv, bias=0)
            caffe2_xavier_init(self.output_convs[i].conv, bias=0)

        caffe2_xavier_init(self.mask_feature, bias=0)
        caffe2_xavier_init(self.encoder_in_proj, bias=0)
        caffe2_xavier_init(self.encoder_out_proj.conv, bias=0)

        for p in self.encoder.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, feats: List[Tensor],
                batch_img_metas: List[dict]) -> Tuple[Tensor, Tensor]:
        """
        Args:
            feats (list[Tensor]): Feature maps of each level. Each has
                shape of (batch_size, c, h, w).
            batch_img_metas (list[dict]): List of image information. Passed
                in for creating a more accurate padding mask.

        Returns:
            tuple: A tuple containing the following:

                - mask_feature (Tensor): Shape (batch_size, c, h, w).
                - memory (Tensor): Shape (batch_size, c, h, w).
        """
        feat_last = feats[-1]
        bs, c, h, w = feat_last.shape
        input_img_h, input_img_w = batch_img_metas[0]['batch_input_shape']
        # Mark padded pixels with 1 and valid pixels with 0, then downsample
        # the mask to the resolution of the last feature map.
        padding_mask = feat_last.new_ones((bs, input_img_h, input_img_w),
                                          dtype=torch.float32)
        for i in range(bs):
            img_h, img_w = batch_img_metas[i]['img_shape']
            padding_mask[i, :img_h, :img_w] = 0
        padding_mask = F.interpolate(
            padding_mask.unsqueeze(1),
            size=feat_last.shape[-2:],
            mode='nearest').to(torch.bool).squeeze(1)

        pos_embed = self.positional_encoding(padding_mask)
        feat_last = self.encoder_in_proj(feat_last)
        # (batch_size, c, h, w) -> (batch_size, num_queries, c)
        feat_last = feat_last.flatten(2).permute(0, 2, 1)
        pos_embed = pos_embed.flatten(2).permute(0, 2, 1)
        # (batch_size, h, w) -> (batch_size, h*w)
        padding_mask = padding_mask.flatten(1)
        memory = self.encoder(
            query=feat_last,
            query_pos=pos_embed,
            key_padding_mask=padding_mask)
        # (batch_size, num_queries, c) -> (batch_size, c, h, w)
        memory = memory.permute(0, 2, 1).view(bs, self.encoder_embed_dims, h,
                                              w)
        y = self.encoder_out_proj(memory)
        # Top-down path: fuse the encoder output with finer backbone levels.
        for i in range(self.num_inputs - 2, -1, -1):
            x = feats[i]
            cur_feat = self.lateral_convs[i](x)
            y = cur_feat + \
                F.interpolate(y, size=cur_feat.shape[-2:], mode='nearest')
            y = self.output_convs[i](y)

        mask_feature = self.mask_feature(y)
        return mask_feature, memory
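

# Usage sketch (illustrative, not part of the original file; the encoder
# config follows the layer_cfg layout of mmdet's DetrTransformerEncoder and
# mirrors common MaskFormer settings, so treat the exact keys as an
# assumption):
#
#   decoder = TransformerEncoderPixelDecoder(
#       in_channels=[256, 512, 1024, 2048],
#       feat_channels=256,
#       out_channels=256,
#       encoder=dict(
#           num_layers=6,
#           layer_cfg=dict(
#               self_attn_cfg=dict(embed_dims=256, num_heads=8),
#               ffn_cfg=dict(embed_dims=256, feedforward_channels=2048))),
#       positional_encoding=dict(num_feats=128, normalize=True))
#   decoder.init_weights()
#   feats = [torch.rand(2, c, s, s)
#            for c, s in [(256, 64), (512, 32), (1024, 16), (2048, 8)]]
#   batch_img_metas = [
#       dict(batch_input_shape=(256, 256), img_shape=(256, 256))
#       for _ in range(2)
#   ]
#   mask_feature, memory = decoder(feats, batch_img_metas)
#   # mask_feature: (2, 256, 64, 64); memory: (2, 256, 8, 8)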