# modified from https://github.com/microsoft/X-Decoder/blob/main/xdecoder/body/transformer_blocks.py  # noqa
"""Transformer class.

Copy-paste from torch.nn.Transformer with modifications:
    * positional encodings are passed in MHattention
    * extra LN at the end of encoder is removed
    * decoder returns a stack of activations from all decoding layers
"""
import copy
import math
from typing import Optional

import torch
import torch.nn.functional as F
from torch import Tensor, nn


class Conv2d(torch.nn.Conv2d):
    """A wrapper around :class:`torch.nn.Conv2d` that additionally supports
    an optional normalization layer and activation function applied after
    the convolution."""

    def __init__(self, *args, **kwargs):
        """Extra keyword arguments supported in addition to those in
        `torch.nn.Conv2d`:

        Args:
            norm (nn.Module, optional): a normalization layer
            activation (callable(Tensor) -> Tensor): a callable
                activation function

        It assumes that the norm layer is applied before the activation.
        """
        norm = kwargs.pop('norm', None)
        activation = kwargs.pop('activation', None)
        super().__init__(*args, **kwargs)

        self.norm = norm
        self.activation = activation

    def forward(self, x):
        x = F.conv2d(x, self.weight, self.bias, self.stride, self.padding,
                     self.dilation, self.groups)
        if self.norm is not None:
            x = self.norm(x)
        if self.activation is not None:
            x = self.activation(x)
        return x
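

# Example usage (illustrative sketch; the channel sizes and norm/activation
# choices below are arbitrary, not prescribed by this module):
#
#     conv = Conv2d(3, 64, kernel_size=3, padding=1,
#                   norm=nn.GroupNorm(32, 64), activation=F.relu)
#     y = conv(torch.randn(2, 3, 32, 32))  # -> (2, 64, 32, 32)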


class PositionEmbeddingSine(nn.Module):
    """This is a more standard version of the position embedding, very
    similar to the one used by the Attention Is All You Need paper,
    generalized to work on images."""

    def __init__(self,
                 num_pos_feats=64,
                 temperature=10000,
                 normalize=False,
                 scale=None):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError('normalize should be True if scale is passed')
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, x, mask=None):
        if mask is None:
            mask = torch.zeros((x.size(0), x.size(2), x.size(3)),
                               device=x.device,
                               dtype=torch.bool)
        not_mask = ~mask
        y_embed = not_mask.cumsum(1, dtype=x.dtype)
        x_embed = not_mask.cumsum(2, dtype=x.dtype)
        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        dim_t = torch.arange(
            self.num_pos_feats, dtype=x.dtype, device=x.device)
        dim_t = self.temperature**(2 * (dim_t // 2) / self.num_pos_feats)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        pos_x = torch.stack(
            (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()),
            dim=4).flatten(3)
        pos_y = torch.stack(
            (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()),
            dim=4).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        return pos

    def __repr__(self, _repr_indent=4):
        head = 'Positional encoding ' + self.__class__.__name__
        body = [
            'num_pos_feats: {}'.format(self.num_pos_feats),
            'temperature: {}'.format(self.temperature),
            'normalize: {}'.format(self.normalize),
            'scale: {}'.format(self.scale),
        ]
        lines = [head] + [' ' * _repr_indent + line for line in body]
        return '\n'.join(lines)
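

# Example usage (illustrative sketch; typically num_pos_feats = d_model // 2
# so the output channel count matches the feature dimension):
#
#     pos_enc = PositionEmbeddingSine(num_pos_feats=128, normalize=True)
#     feat = torch.randn(2, 256, 32, 32)   # (B, C, H, W)
#     pos = pos_enc(feat)                  # (2, 256, 32, 32)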


class TransformerEncoder(nn.Module):

    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(
        self,
        src,
        mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        output = src

        for layer in self.layers:
            output = layer(
                output,
                src_mask=mask,
                src_key_padding_mask=src_key_padding_mask,
                pos=pos)

        if self.norm is not None:
            output = self.norm(output)

        return output


class TransformerEncoderLayer(nn.Module):

    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward=2048,
        dropout=0.1,
        activation='relu',
        normalize_before=False,
    ):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(
            d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(
        self,
        src,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        q = k = self.with_pos_embed(src, pos)
        src2 = self.self_attn(
            q,
            k,
            value=src,
            attn_mask=src_mask,
            key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src

    def forward_pre(
        self,
        src,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        src2 = self.norm1(src)
        q = k = self.with_pos_embed(src2, pos)
        src2 = self.self_attn(
            q,
            k,
            value=src2,
            attn_mask=src_mask,
            key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src2 = self.norm2(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
        src = src + self.dropout2(src2)
        return src

    def forward(
        self,
        src,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        if self.normalize_before:
            return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
        return self.forward_post(src, src_mask, src_key_padding_mask, pos)
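

# Example usage (illustrative sketch; inputs follow the nn.MultiheadAttention
# convention of (sequence, batch, d_model), with the hyperparameters below
# chosen only for the example):
#
#     layer = TransformerEncoderLayer(d_model=256, nhead=8)
#     encoder = TransformerEncoder(layer, num_layers=6)
#     src = torch.randn(1024, 2, 256)              # e.g. flattened H*W tokens
#     pos = torch.randn(1024, 2, 256)              # positional encodings
#     out = encoder(src, pos=pos)                  # same shape as src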


class SelfAttentionLayer(nn.Module):

    def __init__(self,
                 d_model,
                 nhead,
                 dropout=0.0,
                 activation='relu',
                 normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(
            d_model, nhead, dropout=dropout)

        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

        self._reset_parameters()

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(self,
                     tgt,
                     tgt_mask: Optional[Tensor] = None,
                     tgt_key_padding_mask: Optional[Tensor] = None,
                     query_pos: Optional[Tensor] = None):
        q = k = self.with_pos_embed(tgt, query_pos)
        tgt2 = self.self_attn(
            q,
            k,
            value=tgt,
            attn_mask=tgt_mask,
            key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout(tgt2)
        tgt = self.norm(tgt)
        return tgt

    def forward_pre(self,
                    tgt,
                    tgt_mask: Optional[Tensor] = None,
                    tgt_key_padding_mask: Optional[Tensor] = None,
                    query_pos: Optional[Tensor] = None):
        tgt2 = self.norm(tgt)
        q = k = self.with_pos_embed(tgt2, query_pos)
        tgt2 = self.self_attn(
            q,
            k,
            value=tgt2,
            attn_mask=tgt_mask,
            key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout(tgt2)
        return tgt

    def forward(self,
                tgt,
                tgt_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        if self.normalize_before:
            return self.forward_pre(tgt, tgt_mask, tgt_key_padding_mask,
                                    query_pos)
        return self.forward_post(tgt, tgt_mask, tgt_key_padding_mask,
                                 query_pos)
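

# Example usage (illustrative sketch; query tensors are
# (num_queries, batch, d_model), matching the encoder layers above):
#
#     self_attn = SelfAttentionLayer(d_model=256, nhead=8)
#     queries = torch.randn(100, 2, 256)
#     query_pos = torch.randn(100, 2, 256)
#     queries = self_attn(queries, query_pos=query_pos)  # same shape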


class CrossAttentionLayer(nn.Module):

    def __init__(self,
                 d_model,
                 nhead,
                 dropout=0.0,
                 activation='relu',
                 normalize_before=False):
        super().__init__()
        self.multihead_attn = nn.MultiheadAttention(
            d_model, nhead, dropout=dropout)

        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

        self._reset_parameters()

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(self,
                     tgt,
                     memory,
                     memory_mask: Optional[Tensor] = None,
                     memory_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None,
                     query_pos: Optional[Tensor] = None):
        tgt2, avg_attn = self.multihead_attn(
            query=self.with_pos_embed(tgt, query_pos),
            key=self.with_pos_embed(memory, pos),
            value=memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask)
        tgt = tgt + self.dropout(tgt2)
        tgt = self.norm(tgt)
        return tgt, avg_attn

    def forward_pre(self,
                    tgt,
                    memory,
                    memory_mask: Optional[Tensor] = None,
                    memory_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None,
                    query_pos: Optional[Tensor] = None):
        tgt2 = self.norm(tgt)
        tgt2, avg_attn = self.multihead_attn(
            query=self.with_pos_embed(tgt2, query_pos),
            key=self.with_pos_embed(memory, pos),
            value=memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask)
        tgt = tgt + self.dropout(tgt2)
        return tgt, avg_attn

    def forward(self,
                tgt,
                memory,
                memory_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        if self.normalize_before:
            return self.forward_pre(tgt, memory, memory_mask,
                                    memory_key_padding_mask, pos, query_pos)
        return self.forward_post(tgt, memory, memory_mask,
                                 memory_key_padding_mask, pos, query_pos)
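

# Example usage (illustrative sketch; `memory` would typically be flattened
# image features of shape (H*W, batch, d_model), and the layer also returns
# the head-averaged attention weights from nn.MultiheadAttention):
#
#     cross_attn = CrossAttentionLayer(d_model=256, nhead=8)
#     queries = torch.randn(100, 2, 256)
#     memory = torch.randn(1024, 2, 256)
#     queries, attn = cross_attn(queries, memory)  # attn: (2, 100, 1024)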


class FFNLayer(nn.Module):

    def __init__(self,
                 d_model,
                 dim_feedforward=2048,
                 dropout=0.0,
                 activation='relu',
                 normalize_before=False):
        super().__init__()
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm = nn.LayerNorm(d_model)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

        self._reset_parameters()

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(self, tgt):
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout(tgt2)
        tgt = self.norm(tgt)
        return tgt

    def forward_pre(self, tgt):
        tgt2 = self.norm(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout(tgt2)
        return tgt

    def forward(self, tgt):
        if self.normalize_before:
            return self.forward_pre(tgt)
        return self.forward_post(tgt)


class MLP(nn.Module):
    """Very simple multi-layer perceptron (also called FFN)."""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(
            nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x
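

# Example usage (illustrative sketch; e.g. mapping per-query embeddings to
# class logits, with arbitrary sizes):
#
#     mlp = MLP(input_dim=256, hidden_dim=256, output_dim=80, num_layers=3)
#     logits = mlp(torch.randn(2, 100, 256))  # -> (2, 100, 80)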


def get_norm(norm, out_channels):
    """
    Args:
        norm (str or callable): either 'BN' (BatchNorm2d) or
            'GN' (GroupNorm with 32 groups); or a callable that takes a
            channel number and returns the normalization layer as an
            nn.Module. None or an empty string means no normalization.

    Returns:
        nn.Module or None: the normalization layer
    """
    if norm is None:
        return None
    if isinstance(norm, str):
        if len(norm) == 0:
            return None
        norm = {
            'BN': nn.BatchNorm2d,
            'GN': lambda channels: nn.GroupNorm(32, channels),
        }[norm]
    return norm(out_channels)
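

# Example usage (illustrative sketch):
#
#     get_norm('BN', 64)    # -> nn.BatchNorm2d(64)
#     get_norm('GN', 64)    # -> nn.GroupNorm(32, 64)
#     get_norm('', 64)      # -> None
#     get_norm(None, 64)    # -> None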


def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


def _get_activation_fn(activation):
    """Return an activation function given a string."""
    if activation == 'relu':
        return F.relu
    if activation == 'gelu':
        return F.gelu
    if activation == 'glu':
        return F.glu
    raise RuntimeError(
        f'activation should be relu/gelu/glu, not {activation}.')