from typing import Callable, Optional, Union

from torch import nn
from torch.nn import functional as F

from mmdet.registry import MODELS
from .transformer_blocks import (Conv2d, PositionEmbeddingSine,
                                 TransformerEncoder, TransformerEncoderLayer,
                                 get_norm)

# modified from https://github.com/microsoft/X-Decoder/blob/main/xdecoder/body/encoder/transformer_encoder_fpn.py # noqa
class TransformerEncoderOnly(nn.Module):
    """Transformer encoder stack that runs on a flattened 2D feature map."""

    def __init__(self,
                 d_model=512,
                 nhead=8,
                 num_encoder_layers=6,
                 dim_feedforward=2048,
                 dropout=0.1,
                 activation='relu',
                 normalize_before=False):
        super().__init__()
        encoder_layer = TransformerEncoderLayer(d_model, nhead,
                                                dim_feedforward, dropout,
                                                activation, normalize_before)
        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers,
                                          encoder_norm)

        self._reset_parameters()

        self.d_model = d_model
        self.nhead = nhead

    def _reset_parameters(self):
        # Xavier-initialize every weight matrix; 1-D parameters (biases,
        # LayerNorm weights) keep their default initialization.
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, src, mask, pos_embed):
        # flatten NxCxHxW to HWxNxC
        bs, c, h, w = src.shape
        src = src.flatten(2).permute(2, 0, 1)
        pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
        if mask is not None:
            mask = mask.flatten(1)

        memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
        # restore HWxNxC back to NxCxHxW
        return memory.permute(1, 2, 0).view(bs, c, h, w)
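
# A minimal shape sketch for TransformerEncoderOnly (illustrative only; the
# sizes below are assumptions, not values used elsewhere in this file):
#
#   encoder = TransformerEncoderOnly(d_model=256, nhead=8)
#   src = torch.rand(2, 256, 32, 32)   # (N, C, H, W) feature map
#   pos = torch.rand(2, 256, 32, 32)   # positional encoding, same shape
#   out = encoder(src, mask=None, pos_embed=pos)  # -> (2, 256, 32, 32)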
class BasePixelDecoder(nn.Module):
    """FPN-style pixel decoder.

    Builds a 1x1 lateral conv and a 3x3 output conv for every input level
    except the lowest-resolution one, which gets an output conv only.
    """

    def __init__(
        self,
        in_channels,
        conv_dim: int,
        mask_dim: int,
        mask_on: bool,
        norm: Optional[Union[str, Callable]] = None,
    ):
        super().__init__()

        lateral_convs = []
        output_convs = []
        # Convs keep a bias only when no normalization layer follows them.
        use_bias = norm == ''
        for idx, in_channel in enumerate(in_channels):
            if idx == len(in_channels) - 1:
                # Lowest-resolution level: output conv only, no lateral conv.
                output_norm = get_norm(norm, conv_dim)
                output_conv = Conv2d(
                    in_channel,
                    conv_dim,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    bias=use_bias,
                    norm=output_norm,
                    activation=F.relu,
                )
                self.add_module('layer_{}'.format(idx + 1), output_conv)

                lateral_convs.append(None)
                output_convs.append(output_conv)
            else:
                lateral_norm = get_norm(norm, conv_dim)
                output_norm = get_norm(norm, conv_dim)

                lateral_conv = Conv2d(
                    in_channel,
                    conv_dim,
                    kernel_size=1,
                    bias=use_bias,
                    norm=lateral_norm)
                output_conv = Conv2d(
                    conv_dim,
                    conv_dim,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    bias=use_bias,
                    norm=output_norm,
                    activation=F.relu,
                )
                self.add_module('adapter_{}'.format(idx + 1), lateral_conv)
                self.add_module('layer_{}'.format(idx + 1), output_conv)

                lateral_convs.append(lateral_conv)
                output_convs.append(output_conv)
        # Place convs into top-down order (from low to high resolution)
        # to make the top-down computation in forward clearer.
        self.lateral_convs = lateral_convs[::-1]
        self.output_convs = output_convs[::-1]

        self.mask_on = mask_on
        if self.mask_on:
            self.mask_dim = mask_dim
            self.mask_features = Conv2d(
                conv_dim,
                mask_dim,
                kernel_size=3,
                stride=1,
                padding=1,
            )

        self.maskformer_num_feature_levels = 3

# To avoid a name conflict with Mask2Former's TransformerEncoderPixelDecoder,
# this class is named XTransformerEncoderPixelDecoder.
@MODELS.register_module()
class XTransformerEncoderPixelDecoder(BasePixelDecoder):
    """Pixel decoder that runs a transformer encoder on the
    lowest-resolution feature map before the FPN top-down path."""

    def __init__(
        self,
        in_channels,
        transformer_dropout: float = 0.0,
        transformer_nheads: int = 8,
        transformer_dim_feedforward: int = 2048,
        transformer_enc_layers: int = 6,
        transformer_pre_norm: bool = False,
        conv_dim: int = 512,
        mask_dim: int = 512,
        norm: Optional[Union[str, Callable]] = 'GN',
    ):
        super().__init__(
            in_channels,
            conv_dim=conv_dim,
            mask_dim=mask_dim,
            norm=norm,
            mask_on=True)

        self.in_features = ['res2', 'res3', 'res4', 'res5']
        # The transformer consumes the last (lowest-resolution) feature map.
        in_channels = in_channels[-1]
        self.input_proj = Conv2d(in_channels, conv_dim, kernel_size=1)
        self.transformer = TransformerEncoderOnly(
            d_model=conv_dim,
            dropout=transformer_dropout,
            nhead=transformer_nheads,
            dim_feedforward=transformer_dim_feedforward,
            num_encoder_layers=transformer_enc_layers,
            normalize_before=transformer_pre_norm,
        )
        self.pe_layer = PositionEmbeddingSine(conv_dim // 2, normalize=True)

        # Replace the top-level output conv built by BasePixelDecoder so
        # that it consumes the transformer output instead of the raw
        # backbone feature.
        use_bias = norm == ''
        output_norm = get_norm(norm, conv_dim)
        output_conv = Conv2d(
            conv_dim,
            conv_dim,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=use_bias,
            norm=output_norm,
            activation=F.relu,
        )
        delattr(self, 'layer_{}'.format(len(self.in_features)))
        self.add_module('layer_{}'.format(len(self.in_features)), output_conv)
        self.output_convs[0] = output_conv

    def forward(self, features):
        multi_scale_features = []
        num_cur_levels = 0
        # Reverse feature maps into top-down order
        # (from low to high resolution)
        for idx, f in enumerate(self.in_features[::-1]):
            x = features[f]
            lateral_conv = self.lateral_convs[idx]
            output_conv = self.output_convs[idx]
            if lateral_conv is None:
                # Lowest-resolution level: project, then run the transformer.
                transformer = self.input_proj(x)
                pos = self.pe_layer(x)
                transformer = self.transformer(transformer, None, pos)
                y = output_conv(transformer)
            else:
                cur_fpn = lateral_conv(x)
                # Following FPN implementation, we use nearest upsampling here
                y = cur_fpn + F.interpolate(
                    y, size=cur_fpn.shape[-2:], mode='nearest')
                y = output_conv(y)
            if num_cur_levels < self.maskformer_num_feature_levels:
                multi_scale_features.append(y)
                num_cur_levels += 1
        mask_features = self.mask_features(y)
        return mask_features, multi_scale_features
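
if __name__ == '__main__':
    # A minimal smoke test (illustrative only). The channel dims below are
    # assumptions matching a typical ResNet-50 backbone, and running this
    # requires the package context for the relative import at the top of
    # this file.
    import torch

    decoder = XTransformerEncoderPixelDecoder(
        in_channels=[256, 512, 1024, 2048], conv_dim=512, mask_dim=512)
    features = {
        'res2': torch.rand(2, 256, 64, 64),
        'res3': torch.rand(2, 512, 32, 32),
        'res4': torch.rand(2, 1024, 16, 16),
        'res5': torch.rand(2, 2048, 8, 8),
    }
    mask_features, multi_scale_features = decoder(features)
    # Expected: mask_features at res2 scale, (2, 512, 64, 64), and three
    # multi-scale maps ordered from low to high resolution.
    print(mask_features.shape, [f.shape for f in multi_scale_features])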