import copy
import math
from typing import Optional

import torch
import torch.nn.functional as F
from torch import Tensor, nn

# modified from https://github.com/microsoft/X-Decoder/blob/main/xdecoder/body/transformer_blocks.py # noqa
# Transformer class.
#
# Copy-paste from torch.nn.Transformer with modifications:
#     * positional encodings are passed in MHattention
#     * extra LN at the end of encoder is removed
#     * decoder returns a stack of activations from all decoding layers


class Conv2d(torch.nn.Conv2d):
    """A wrapper around :class:`torch.nn.Conv2d` that can additionally apply
    a normalization layer and an activation after the convolution."""

    def __init__(self, *args, **kwargs):
        """Extra keyword arguments supported in addition to those in
        `torch.nn.Conv2d`:

        Args:
            norm (nn.Module, optional): a normalization layer
            activation (callable(Tensor) -> Tensor): a callable activation
                function

        The norm layer (if any) is applied before the activation.
        """
        norm = kwargs.pop('norm', None)
        activation = kwargs.pop('activation', None)
        super().__init__(*args, **kwargs)

        self.norm = norm
        self.activation = activation

    def forward(self, x):
        """Convolve ``x``, then apply ``norm`` and ``activation`` if set."""
        # Delegate to nn.Conv2d.forward so every constructor option is
        # honoured, including a non-'zeros' ``padding_mode``. A direct
        # F.conv2d call would silently fall back to zero padding.
        x = super().forward(x)
        if self.norm is not None:
            x = self.norm(x)
        if self.activation is not None:
            x = self.activation(x)
        return x


class PositionEmbeddingSine(nn.Module):
    """This is a more standard version of the position embedding, very
    similar to the one used by the Attention is all you need paper,
    generalized to work on images.

    Args:
        num_pos_feats: number of features per spatial axis. The output has
            ``2 * num_pos_feats`` channels (y features followed by x
            features). Must be even, because sin/cos halves are interleaved.
        temperature: base of the geometric frequency progression.
        normalize: if True, cumulative positions are rescaled to
            ``[0, scale]`` along each axis.
        scale: rescaling factor used when ``normalize`` is True
            (default ``2 * pi``). Passing it without ``normalize`` raises
            ``ValueError``.
    """

    def __init__(self,
                 num_pos_feats=64,
                 temperature=10000,
                 normalize=False,
                 scale=None):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and not normalize:
            raise ValueError('normalize should be True if scale is passed')
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, x, mask=None):
        """Compute the positional encoding for ``x``.

        ``x`` is indexed as a 4-D tensor ``(B, C, H, W)``; only its batch and
        spatial sizes, dtype and device are used. ``mask`` (optional, bool,
        ``(B, H, W)``) marks padded positions, which are excluded from the
        cumulative position count.

        Returns:
            Tensor of shape ``(B, 2 * num_pos_feats, H, W)``.
        """
        if mask is None:
            mask = torch.zeros((x.size(0), x.size(2), x.size(3)),
                               device=x.device,
                               dtype=torch.bool)
        not_mask = ~mask
        # 1-based row / column indices, counting only unmasked positions.
        y_embed = not_mask.cumsum(1, dtype=x.dtype)
        x_embed = not_mask.cumsum(2, dtype=x.dtype)
        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        dim_t = torch.arange(
            self.num_pos_feats, dtype=x.dtype, device=x.device)
        # Each sin/cos pair of channels shares one frequency.
        dim_t = self.temperature**(2 * (dim_t // 2) / self.num_pos_feats)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        pos_x = torch.stack(
            (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()),
            dim=4).flatten(3)
        pos_y = torch.stack(
            (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()),
            dim=4).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        return pos

    def __repr__(self, _repr_indent=4):
        head = 'Positional encoding ' + self.__class__.__name__
        body = [
            'num_pos_feats: {}'.format(self.num_pos_feats),
            'temperature: {}'.format(self.temperature),
            'normalize: {}'.format(self.normalize),
            'scale: {}'.format(self.scale),
        ]
        lines = [head] + [' ' * _repr_indent + line for line in body]
        return '\n'.join(lines)


class TransformerEncoder(nn.Module):
    """Stack of ``num_layers`` deep copies of ``encoder_layer``, optionally
    followed by a final ``norm`` module."""

    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(
        self,
        src,
        mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        """Run ``src`` through every layer in turn; the same masks and
        positional embedding are passed to each layer."""
        output = src

        for layer in self.layers:
            output = layer(
                output,
                src_mask=mask,
                src_key_padding_mask=src_key_padding_mask,
                pos=pos)

        if self.norm is not None:
            output = self.norm(output)

        return output


class TransformerEncoderLayer(nn.Module):
    """Self-attention + feed-forward encoder layer.

    Uses ``nn.MultiheadAttention`` with its default ``batch_first=False``,
    so inputs are laid out sequence-first. ``normalize_before`` selects
    pre-norm (True) or post-norm (False) residual blocks.
    """

    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward=2048,
        dropout=0.1,
        activation='relu',
        normalize_before=False,
    ):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        """Add ``pos`` to ``tensor`` unless ``pos`` is None."""
        return tensor if pos is None else tensor + pos

    def forward_post(
        self,
        src,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        """Post-norm: residual add first, then LayerNorm."""
        # Positional embedding goes into query/key only, not the value.
        q = k = self.with_pos_embed(src, pos)
        src2 = self.self_attn(
            q,
            k,
            value=src,
            attn_mask=src_mask,
            key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src

    def forward_pre(
        self,
        src,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        """Pre-norm: LayerNorm the branch input, then residual add."""
        src2 = self.norm1(src)
        q = k = self.with_pos_embed(src2, pos)
        src2 = self.self_attn(
            q,
            k,
            value=src2,
            attn_mask=src_mask,
            key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src2 = self.norm2(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
        src = src + self.dropout2(src2)
        return src

    def forward(
        self,
        src,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        if self.normalize_before:
            return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
        return self.forward_post(src, src_mask, src_key_padding_mask, pos)


class SelfAttentionLayer(nn.Module):
    """Residual multi-head self-attention block (no feed-forward).

    ``activation`` is resolved and stored but not used in ``forward``.
    Weight matrices are Xavier-uniform initialised.
    """

    def __init__(self,
                 d_model,
                 nhead,
                 dropout=0.0,
                 activation='relu',
                 normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

        self._reset_parameters()

    def _reset_parameters(self):
        # Only >1-D parameters (weight matrices); biases/LayerNorm untouched.
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        """Add ``pos`` to ``tensor`` unless ``pos`` is None."""
        return tensor if pos is None else tensor + pos

    def forward_post(self,
                     tgt,
                     tgt_mask: Optional[Tensor] = None,
                     tgt_key_padding_mask: Optional[Tensor] = None,
                     query_pos: Optional[Tensor] = None):
        """Post-norm self-attention."""
        q = k = self.with_pos_embed(tgt, query_pos)
        tgt2 = self.self_attn(
            q,
            k,
            value=tgt,
            attn_mask=tgt_mask,
            key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout(tgt2)
        tgt = self.norm(tgt)

        return tgt

    def forward_pre(self,
                    tgt,
                    tgt_mask: Optional[Tensor] = None,
                    tgt_key_padding_mask: Optional[Tensor] = None,
                    query_pos: Optional[Tensor] = None):
        """Pre-norm self-attention."""
        tgt2 = self.norm(tgt)
        q = k = self.with_pos_embed(tgt2, query_pos)
        tgt2 = self.self_attn(
            q,
            k,
            value=tgt2,
            attn_mask=tgt_mask,
            key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout(tgt2)

        return tgt

    def forward(self,
                tgt,
                tgt_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        if self.normalize_before:
            return self.forward_pre(tgt, tgt_mask, tgt_key_padding_mask,
                                    query_pos)
        return self.forward_post(tgt, tgt_mask, tgt_key_padding_mask,
                                 query_pos)


class CrossAttentionLayer(nn.Module):
    """Residual multi-head cross-attention block: ``tgt`` attends to
    ``memory``.

    ``forward`` returns ``(tgt, attn_weights)``. ``attn_weights`` is the
    second output of ``nn.MultiheadAttention``, which is head-averaged under
    PyTorch's defaults. ``activation`` is resolved but unused.
    """

    def __init__(self,
                 d_model,
                 nhead,
                 dropout=0.0,
                 activation='relu',
                 normalize_before=False):
        super().__init__()
        self.multihead_attn = nn.MultiheadAttention(
            d_model, nhead, dropout=dropout)

        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

        self._reset_parameters()

    def _reset_parameters(self):
        # Only >1-D parameters (weight matrices); biases/LayerNorm untouched.
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        """Add ``pos`` to ``tensor`` unless ``pos`` is None."""
        return tensor if pos is None else tensor + pos

    def forward_post(self,
                     tgt,
                     memory,
                     memory_mask: Optional[Tensor] = None,
                     memory_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None,
                     query_pos: Optional[Tensor] = None):
        """Post-norm cross-attention; returns ``(tgt, attn_weights)``."""
        tgt2, avg_attn = self.multihead_attn(
            query=self.with_pos_embed(tgt, query_pos),
            key=self.with_pos_embed(memory, pos),
            value=memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask)
        tgt = tgt + self.dropout(tgt2)
        tgt = self.norm(tgt)

        return tgt, avg_attn

    def forward_pre(self,
                    tgt,
                    memory,
                    memory_mask: Optional[Tensor] = None,
                    memory_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None,
                    query_pos: Optional[Tensor] = None):
        """Pre-norm cross-attention; returns ``(tgt, attn_weights)``."""
        tgt2 = self.norm(tgt)
        tgt2, avg_attn = self.multihead_attn(
            query=self.with_pos_embed(tgt2, query_pos),
            key=self.with_pos_embed(memory, pos),
            value=memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask)
        tgt = tgt + self.dropout(tgt2)

        return tgt, avg_attn

    def forward(self,
                tgt,
                memory,
                memory_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        if self.normalize_before:
            return self.forward_pre(tgt, memory, memory_mask,
                                    memory_key_padding_mask, pos, query_pos)
        return self.forward_post(tgt, memory, memory_mask,
                                 memory_key_padding_mask, pos, query_pos)


class FFNLayer(nn.Module):
    """Residual two-layer feed-forward block (Linear -> act -> Linear).

    Weight matrices are Xavier-uniform initialised.
    """

    def __init__(self,
                 d_model,
                 dim_feedforward=2048,
                 dropout=0.0,
                 activation='relu',
                 normalize_before=False):
        super().__init__()
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm = nn.LayerNorm(d_model)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

        self._reset_parameters()

    def _reset_parameters(self):
        # Only >1-D parameters (weight matrices); biases/LayerNorm untouched.
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        """Add ``pos`` to ``tensor`` unless ``pos`` is None."""
        return tensor if pos is None else tensor + pos

    def forward_post(self, tgt):
        """Post-norm feed-forward."""
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout(tgt2)
        tgt = self.norm(tgt)
        return tgt

    def forward_pre(self, tgt):
        """Pre-norm feed-forward."""
        tgt2 = self.norm(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout(tgt2)
        return tgt

    def forward(self, tgt):
        if self.normalize_before:
            return self.forward_pre(tgt)
        return self.forward_post(tgt)


class MLP(nn.Module):
    """Very simple multi-layer perceptron (also called FFN).

    ``num_layers`` Linear layers with ReLU between them and no activation
    after the last one.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(
            nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x


def get_norm(norm, out_channels):
    """Build a 2-D normalization layer.

    Args:
        norm (str or callable): either one of ``'BN'`` (``nn.BatchNorm2d``)
            or ``'GN'`` (``nn.GroupNorm`` with 32 groups); or a callable
            that takes a channel number and returns the normalization layer
            as a nn.Module. Any other string raises ``KeyError``.
        out_channels (int): number of channels to normalize.

    Returns:
        nn.Module or None: the normalization layer, or None when ``norm`` is
        None or an empty string.
    """
    if norm is None:
        return None
    if isinstance(norm, str):
        if len(norm) == 0:
            return None
        norm = {
            'BN': nn.BatchNorm2d,
            'GN': lambda channels: nn.GroupNorm(32, channels),
        }[norm]
    return norm(out_channels)


def _get_clones(module, N):
    """Return an ``nn.ModuleList`` of ``N`` independent deep copies."""
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


def _get_activation_fn(activation):
    """Return an activation function given a string.

    Raises:
        RuntimeError: if ``activation`` is not 'relu', 'gelu' or 'glu'.
    """
    if activation == 'relu':
        return F.relu
    if activation == 'gelu':
        return F.gelu
    if activation == 'glu':
        return F.glu
    raise RuntimeError(
        f'activation should be relu/gelu/glu, not {activation}.')