pixel_decoder.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214
  1. from typing import Callable, Optional, Union
  2. from torch import nn
  3. from torch.nn import functional as F
  4. from mmdet.registry import MODELS
  5. from .transformer_blocks import (Conv2d, PositionEmbeddingSine,
  6. TransformerEncoder, TransformerEncoderLayer,
  7. get_norm)
  8. # modified from https://github.com/microsoft/X-Decoder/blob/main/xdecoder/body/encoder/transformer_encoder_fpn.py # noqa
  9. class TransformerEncoderOnly(nn.Module):
  10. def __init__(self,
  11. d_model=512,
  12. nhead=8,
  13. num_encoder_layers=6,
  14. dim_feedforward=2048,
  15. dropout=0.1,
  16. activation='relu',
  17. normalize_before=False):
  18. super().__init__()
  19. encoder_layer = TransformerEncoderLayer(d_model, nhead,
  20. dim_feedforward, dropout,
  21. activation, normalize_before)
  22. encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
  23. self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers,
  24. encoder_norm)
  25. self._reset_parameters()
  26. self.d_model = d_model
  27. self.nhead = nhead
  28. def _reset_parameters(self):
  29. for p in self.parameters():
  30. if p.dim() > 1:
  31. nn.init.xavier_uniform_(p)
  32. def forward(self, src, mask, pos_embed):
  33. # flatten NxCxHxW to HWxNxC
  34. bs, c, h, w = src.shape
  35. src = src.flatten(2).permute(2, 0, 1)
  36. pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
  37. if mask is not None:
  38. mask = mask.flatten(1)
  39. memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
  40. return memory.permute(1, 2, 0).view(bs, c, h, w)
  41. class BasePixelDecoder(nn.Module):
  42. def __init__(
  43. self,
  44. in_channels,
  45. conv_dim: int,
  46. mask_dim: int,
  47. mask_on: bool,
  48. norm: Optional[Union[str, Callable]] = None,
  49. ):
  50. super().__init__()
  51. lateral_convs = []
  52. output_convs = []
  53. use_bias = norm == ''
  54. for idx, in_channel in enumerate(in_channels):
  55. if idx == len(in_channels) - 1:
  56. output_norm = get_norm(norm, conv_dim)
  57. output_conv = Conv2d(
  58. in_channel,
  59. conv_dim,
  60. kernel_size=3,
  61. stride=1,
  62. padding=1,
  63. bias=use_bias,
  64. norm=output_norm,
  65. activation=F.relu,
  66. )
  67. self.add_module('layer_{}'.format(idx + 1), output_conv)
  68. lateral_convs.append(None)
  69. output_convs.append(output_conv)
  70. else:
  71. lateral_norm = get_norm(norm, conv_dim)
  72. output_norm = get_norm(norm, conv_dim)
  73. lateral_conv = Conv2d(
  74. in_channel,
  75. conv_dim,
  76. kernel_size=1,
  77. bias=use_bias,
  78. norm=lateral_norm)
  79. output_conv = Conv2d(
  80. conv_dim,
  81. conv_dim,
  82. kernel_size=3,
  83. stride=1,
  84. padding=1,
  85. bias=use_bias,
  86. norm=output_norm,
  87. activation=F.relu,
  88. )
  89. self.add_module('adapter_{}'.format(idx + 1), lateral_conv)
  90. self.add_module('layer_{}'.format(idx + 1), output_conv)
  91. lateral_convs.append(lateral_conv)
  92. output_convs.append(output_conv)
  93. # Place convs into top-down order (from low to high resolution)
  94. # to make the top-down computation in forward clearer.
  95. self.lateral_convs = lateral_convs[::-1]
  96. self.output_convs = output_convs[::-1]
  97. self.mask_on = mask_on
  98. if self.mask_on:
  99. self.mask_dim = mask_dim
  100. self.mask_features = Conv2d(
  101. conv_dim,
  102. mask_dim,
  103. kernel_size=3,
  104. stride=1,
  105. padding=1,
  106. )
  107. self.maskformer_num_feature_levels = 3
  108. # To prevent conflicts with TransformerEncoderPixelDecoder in mask2former,
  109. # we change the name to XTransformerEncoderPixelDecoder
  110. @MODELS.register_module()
  111. class XTransformerEncoderPixelDecoder(BasePixelDecoder):
  112. def __init__(
  113. self,
  114. in_channels,
  115. transformer_dropout: float = 0.0,
  116. transformer_nheads: int = 8,
  117. transformer_dim_feedforward: int = 2048,
  118. transformer_enc_layers: int = 6,
  119. transformer_pre_norm: bool = False,
  120. conv_dim: int = 512,
  121. mask_dim: int = 512,
  122. norm: Optional[Union[str, Callable]] = 'GN',
  123. ):
  124. super().__init__(
  125. in_channels,
  126. conv_dim=conv_dim,
  127. mask_dim=mask_dim,
  128. norm=norm,
  129. mask_on=True)
  130. self.in_features = ['res2', 'res3', 'res4', 'res5']
  131. feature_channels = in_channels
  132. in_channels = feature_channels[len(in_channels) - 1]
  133. self.input_proj = Conv2d(in_channels, conv_dim, kernel_size=1)
  134. self.transformer = TransformerEncoderOnly(
  135. d_model=conv_dim,
  136. dropout=transformer_dropout,
  137. nhead=transformer_nheads,
  138. dim_feedforward=transformer_dim_feedforward,
  139. num_encoder_layers=transformer_enc_layers,
  140. normalize_before=transformer_pre_norm,
  141. )
  142. self.pe_layer = PositionEmbeddingSine(conv_dim // 2, normalize=True)
  143. # update layer
  144. use_bias = norm == ''
  145. output_norm = get_norm(norm, conv_dim)
  146. output_conv = Conv2d(
  147. conv_dim,
  148. conv_dim,
  149. kernel_size=3,
  150. stride=1,
  151. padding=1,
  152. bias=use_bias,
  153. norm=output_norm,
  154. activation=F.relu,
  155. )
  156. delattr(self, 'layer_{}'.format(len(self.in_features)))
  157. self.add_module('layer_{}'.format(len(self.in_features)), output_conv)
  158. self.output_convs[0] = output_conv
  159. def forward(self, features):
  160. multi_scale_features = []
  161. num_cur_levels = 0
  162. # Reverse feature maps into top-down order
  163. # (from low to high resolution)
  164. for idx, f in enumerate(self.in_features[::-1]):
  165. x = features[f]
  166. lateral_conv = self.lateral_convs[idx]
  167. output_conv = self.output_convs[idx]
  168. if lateral_conv is None:
  169. transformer = self.input_proj(x)
  170. pos = self.pe_layer(x)
  171. transformer = self.transformer(transformer, None, pos)
  172. y = output_conv(transformer)
  173. else:
  174. cur_fpn = lateral_conv(x)
  175. # Following FPN implementation, we use nearest upsampling here
  176. y = cur_fpn + F.interpolate(
  177. y, size=cur_fpn.shape[-2:], mode='nearest')
  178. y = output_conv(y)
  179. if num_cur_levels < self.maskformer_num_feature_levels:
  180. multi_scale_features.append(y)
  181. num_cur_levels += 1
  182. mask_features = self.mask_features(y)
  183. return mask_features, multi_scale_features