# Copyright (c) OpenMMLab. All rights reserved.
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from torch import Tensor

try:
    from mmcv.cnn.bricks import DropPath
except ImportError:
    # ``DropPath`` is only instantiated when ``drop_path > 0``; guard the
    # import the same way as the optional ``transformers`` import below so
    # that the rest of this module stays importable.
    DropPath = None

try:
    from transformers import BertPreTrainedModel
    from transformers.modeling_utils import apply_chunking_to_forward
    from transformers.models.bert.modeling_bert import \
        BertAttention as HFBertAttention
    from transformers.models.bert.modeling_bert import \
        BertIntermediate as HFBertIntermediate
    from transformers.models.bert.modeling_bert import \
        BertOutput as HFBertOutput
except ImportError:
    BertPreTrainedModel = object
    apply_chunking_to_forward = None
    HFBertAttention = object
    HFBertIntermediate = object
    HFBertOutput = object

MAX_CLAMP_VALUE = 50000


def permute_and_flatten(layer, N, A, C, H, W):
    """Rearrange a ``(N, A * C, H, W)`` tensor into ``(N, H * W * A, C)``.

    Args:
        layer (Tensor): Tensor that can be viewed as ``(N, A, C, H, W)``.
        N (int): Size of the first dimension.
        A (int): Size of the second dimension of the 5-D view; may be -1
            to let ``view`` infer it.
        C (int): Size of the last dimension of the result.
        H (int): Height of the map.
        W (int): Width of the map.

    Returns:
        Tensor: Tensor of shape ``(N, H * W * A, C)``.
    """
    layer = layer.view(N, A, C, H, W)
    layer = layer.permute(0, 3, 4, 1, 2)
    layer = layer.reshape(N, -1, C)
    return layer


def clamp_values(vector):
    """Clamp ``vector`` to ``[-MAX_CLAMP_VALUE, MAX_CLAMP_VALUE]``."""
    vector = torch.clamp(vector, min=-MAX_CLAMP_VALUE, max=MAX_CLAMP_VALUE)
    return vector


class BiMultiHeadAttention(nn.Module):
    """Bidirectional fusion Multi-Head Attention layer.

    A single vision-language similarity matrix is used in both directions:
    vision tokens attend to language tokens and language tokens attend to
    vision tokens.

    Args:
        v_dim (int): Feature dimension of the vision input.
        l_dim (int): Feature dimension of the language input.
        embed_dim (int): Dimension of the shared projection space. Must be
            divisible by ``num_heads``.
        num_heads (int): Number of attention heads.
        dropout (float): Dropout probability applied to the attention
            probabilities. Defaults to 0.1.
    """

    def __init__(self,
                 v_dim: int,
                 l_dim: int,
                 embed_dim: int,
                 num_heads: int,
                 dropout: float = 0.1):
        super(BiMultiHeadAttention, self).__init__()

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.v_dim = v_dim
        self.l_dim = l_dim

        assert (
            self.head_dim * self.num_heads == self.embed_dim
        ), 'embed_dim must be divisible by num_heads ' \
           f'(got `embed_dim`: {self.embed_dim} ' \
           f'and `num_heads`: {self.num_heads}).'
        self.scale = self.head_dim**(-0.5)
        self.dropout = dropout

        self.v_proj = nn.Linear(self.v_dim, self.embed_dim)
        self.l_proj = nn.Linear(self.l_dim, self.embed_dim)
        self.values_v_proj = nn.Linear(self.v_dim, self.embed_dim)
        self.values_l_proj = nn.Linear(self.l_dim, self.embed_dim)

        self.out_v_proj = nn.Linear(self.embed_dim, self.v_dim)
        self.out_l_proj = nn.Linear(self.embed_dim, self.l_dim)

        self.stable_softmax_2d = False
        self.clamp_min_for_underflow = True
        self.clamp_max_for_overflow = True

        self._reset_parameters()

    def _shape(self, tensor: Tensor, seq_len: int, bsz: int):
        """View ``tensor`` as ``(bsz, num_heads, seq_len, head_dim)``."""
        return tensor.view(bsz, seq_len, self.num_heads,
                           self.head_dim).transpose(1, 2).contiguous()

    def _reset_parameters(self):
        """Xavier-initialize all projection weights and zero the biases."""
        nn.init.xavier_uniform_(self.v_proj.weight)
        self.v_proj.bias.data.fill_(0)
        nn.init.xavier_uniform_(self.l_proj.weight)
        self.l_proj.bias.data.fill_(0)
        nn.init.xavier_uniform_(self.values_v_proj.weight)
        self.values_v_proj.bias.data.fill_(0)
        nn.init.xavier_uniform_(self.values_l_proj.weight)
        self.values_l_proj.bias.data.fill_(0)
        nn.init.xavier_uniform_(self.out_v_proj.weight)
        self.out_v_proj.bias.data.fill_(0)
        nn.init.xavier_uniform_(self.out_l_proj.weight)
        self.out_l_proj.bias.data.fill_(0)

    def forward(self, vision: Tensor, lang: Tensor, attention_mask_l=None):
        """Fuse vision and language features in both directions.

        Args:
            vision (Tensor): Vision tokens of shape
                ``(bsz, tgt_len, v_dim)``.
            lang (Tensor): Language tokens of shape
                ``(bsz, src_len, l_dim)``.
            attention_mask_l (Tensor, optional): 2-D language mask of shape
                ``(bsz, src_len)``. Positions equal to 0 (or ``False``) are
                excluded when vision attends to language. Integer, float
                and boolean masks are accepted. Defaults to None.

        Returns:
            tuple[Tensor, Tensor]: The vision update of shape
            ``(bsz, tgt_len, v_dim)`` and the language update of shape
            ``(bsz, src_len, l_dim)``.
        """
        bsz, tgt_len, _ = vision.size()

        query_states = self.v_proj(vision) * self.scale
        key_states = self._shape(self.l_proj(lang), -1, bsz)
        value_v_states = self._shape(self.values_v_proj(vision), -1, bsz)
        value_l_states = self._shape(self.values_l_proj(lang), -1, bsz)

        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len,
                                   bsz).view(*proj_shape)
        key_states = key_states.view(*proj_shape)
        value_v_states = value_v_states.view(*proj_shape)
        value_l_states = value_l_states.view(*proj_shape)

        src_len = key_states.size(1)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f'Attention weights should be of '
                f'size {(bsz * self.num_heads, tgt_len, src_len)}, '
                f'but is {attn_weights.size()}')

        if self.stable_softmax_2d:
            attn_weights = attn_weights - attn_weights.max()

        if self.clamp_min_for_underflow:
            # Do not increase -50000, data type half has quite limited range
            attn_weights = torch.clamp(attn_weights, min=-MAX_CLAMP_VALUE)
        if self.clamp_max_for_overflow:
            # Do not increase 50000, data type half has quite limited range
            attn_weights = torch.clamp(attn_weights, max=MAX_CLAMP_VALUE)

        # Language -> vision direction: reuse the transposed similarity
        # matrix and subtract the row maximum before the softmax.
        attn_weights_T = attn_weights.transpose(1, 2)
        attn_weights_l = (
            attn_weights_T -
            torch.max(attn_weights_T, dim=-1, keepdim=True)[0])
        if self.clamp_min_for_underflow:
            # Do not increase -50000, data type half has quite limited range
            attn_weights_l = torch.clamp(attn_weights_l, min=-MAX_CLAMP_VALUE)
        if self.clamp_max_for_overflow:
            # Do not increase 50000, data type half has quite limited range
            attn_weights_l = torch.clamp(attn_weights_l, max=MAX_CLAMP_VALUE)

        attn_weights_l = attn_weights_l.softmax(dim=-1)

        if attention_mask_l is not None:
            assert (attention_mask_l.dim() == 2)
            if attention_mask_l.dtype == torch.bool:
                # A boolean tensor cannot hold the large negative fill
                # value below, so a bool mask would not mask anything.
                # Use the integer representation instead.
                attention_mask_l = attention_mask_l.long()
            attention_mask = attention_mask_l.unsqueeze(1).unsqueeze(1)
            attention_mask = attention_mask.expand(bsz, 1, tgt_len, src_len)
            attention_mask = attention_mask.masked_fill(
                attention_mask == 0, -9e15)

            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError('Attention mask should be of '
                                 f'size {(bsz, 1, tgt_len, src_len)}')
            # The mask is only applied to the vision -> language weights.
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len,
                                             src_len) + attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len,
                                             src_len)

        attn_weights_v = nn.functional.softmax(attn_weights, dim=-1)

        attn_probs_v = F.dropout(
            attn_weights_v, p=self.dropout, training=self.training)
        attn_probs_l = F.dropout(
            attn_weights_l, p=self.dropout, training=self.training)

        # Vision queries aggregate language values and vice versa.
        attn_output_v = torch.bmm(attn_probs_v, value_l_states)
        attn_output_l = torch.bmm(attn_probs_l, value_v_states)

        if attn_output_v.size() != (bsz * self.num_heads, tgt_len,
                                    self.head_dim):
            raise ValueError(
                '`attn_output_v` should be of '
                f'size {(bsz * self.num_heads, tgt_len, self.head_dim)}, '
                f'but is {attn_output_v.size()}')

        if attn_output_l.size() != (bsz * self.num_heads, src_len,
                                    self.head_dim):
            raise ValueError(
                '`attn_output_l` should be of size '
                f'{(bsz * self.num_heads, src_len, self.head_dim)}, '
                f'but is {attn_output_l.size()}')

        attn_output_v = attn_output_v.view(bsz, self.num_heads, tgt_len,
                                           self.head_dim)
        attn_output_v = attn_output_v.transpose(1, 2)
        attn_output_v = attn_output_v.reshape(bsz, tgt_len, self.embed_dim)

        attn_output_l = attn_output_l.view(bsz, self.num_heads, src_len,
                                           self.head_dim)
        attn_output_l = attn_output_l.transpose(1, 2)
        attn_output_l = attn_output_l.reshape(bsz, src_len, self.embed_dim)

        attn_output_v = self.out_v_proj(attn_output_v)
        attn_output_l = self.out_l_proj(attn_output_l)

        return attn_output_v, attn_output_l


class BiAttentionBlock(nn.Module):
    """BiAttentionBlock Module:

    First, multi-level visual features are concat; Then the concat visual
    feature and lang feature are fused by attention; Finally the newly visual
    feature are split into multi levels.

    Args:
        v_dim (int): Feature dimension of the visual input.
        l_dim (int): Feature dimension of the language input.
        embed_dim (int): Dimension of the attention projection space.
        num_heads (int): Number of attention heads.
        dropout (float): Dropout probability inside the attention.
            Defaults to 0.1.
        drop_path (float): Drop path probability applied to both residual
            branches. Defaults to 0.0.
        init_values (float): Initial value of the learnable per-channel
            scales ``gamma_v`` and ``gamma_l``. Defaults to 1e-4.
    """

    def __init__(self,
                 v_dim: int,
                 l_dim: int,
                 embed_dim: int,
                 num_heads: int,
                 dropout: float = 0.1,
                 drop_path: float = .0,
                 init_values: float = 1e-4):
        super().__init__()

        # pre layer norm
        self.layer_norm_v = nn.LayerNorm(v_dim)
        self.layer_norm_l = nn.LayerNorm(l_dim)
        self.attn = BiMultiHeadAttention(
            v_dim=v_dim,
            l_dim=l_dim,
            embed_dim=embed_dim,
            num_heads=num_heads,
            dropout=dropout)

        # add layer scale for training stability
        if drop_path > 0.:
            if DropPath is None:
                raise ImportError(
                    'mmcv is required when `drop_path` > 0, '
                    'please install mmcv to use DropPath.')
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.gamma_v = nn.Parameter(
            init_values * torch.ones(v_dim), requires_grad=True)
        self.gamma_l = nn.Parameter(
            init_values * torch.ones(l_dim), requires_grad=True)

    def forward(self,
                visual_features: list,
                lang_feature: Tensor,
                attention_mask_l=None):
        """Fuse multi-level visual features with the language feature.

        Args:
            visual_features (list[Tensor]): Per-level feature maps, each of
                shape ``(bs, c, h, w)``.
            lang_feature (Tensor): Language feature passed to the
                attention together with the flattened visual tokens.
            attention_mask_l (Tensor, optional): Language mask forwarded to
                the attention. Defaults to None.

        Returns:
            tuple[list[Tensor], Tensor]: The fused visual feature maps, one
            per input level with the same spatial size, and the fused
            language feature.
        """
        size_per_level, visual_features_flatten = [], []
        for feat_per_level in visual_features:
            bs, c, h, w = feat_per_level.shape
            size_per_level.append([h, w])
            feat = permute_and_flatten(feat_per_level, bs, -1, c, h, w)
            visual_features_flatten.append(feat)
        visual_features_flatten = torch.cat(visual_features_flatten, dim=1)
        new_v, new_lang_feature = self.single_attention_call(
            visual_features_flatten,
            lang_feature,
            attention_mask_l=attention_mask_l)
        # [bs, N, C] -> [bs, C, N]
        new_v = new_v.transpose(1, 2).contiguous()
        bs = new_v.size(0)

        # Split the token axis back into the original per-level maps.
        start = 0
        fusion_visual_features = []
        for (h, w) in size_per_level:
            new_v_per_level = new_v[:, :,
                                    start:start + h * w].view(bs, -1, h,
                                                              w).contiguous()
            fusion_visual_features.append(new_v_per_level)
            start += h * w

        return fusion_visual_features, new_lang_feature

    def single_attention_call(self, visual, lang, attention_mask_l=None):
        """Run pre-norm bidirectional attention with scaled residuals.

        Args:
            visual (Tensor): Flattened visual tokens.
            lang (Tensor): Language tokens.
            attention_mask_l (Tensor, optional): Language mask forwarded to
                the attention. Defaults to None.

        Returns:
            tuple[Tensor, Tensor]: Updated visual and language tokens.
        """
        visual = self.layer_norm_v(visual)
        lang = self.layer_norm_l(lang)
        delta_v, delta_l = self.attn(
            visual, lang, attention_mask_l=attention_mask_l)
        # The residual is added onto the layer-normed features, with the
        # attention update scaled by the learnable gamma parameters.
        visual = visual + self.drop_path(self.gamma_v * delta_v)
        lang = lang + self.drop_path(self.gamma_l * delta_l)
        return visual, lang


class VLFuse(nn.Module):
    """Early Fusion Module.

    Args:
        v_dim (int): Feature dimension of the visual input.
            Defaults to 256.
        l_dim (int): Feature dimension of the language input.
            Defaults to 768.
        embed_dim (int): Dimension of the attention projection space.
            Defaults to 2048.
        num_heads (int): Number of attention heads. Defaults to 8.
        dropout (float): Dropout probability inside the attention.
            Defaults to 0.1.
        drop_path (float): Drop path probability. Defaults to 0.0.
        use_checkpoint (bool): Whether to run the fusion block through
            ``torch.utils.checkpoint``. Defaults to False.
    """

    def __init__(self,
                 v_dim: int = 256,
                 l_dim: int = 768,
                 embed_dim: int = 2048,
                 num_heads: int = 8,
                 dropout: float = 0.1,
                 drop_path: float = 0.0,
                 use_checkpoint: bool = False):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        # bi-direction (text->image, image->text)
        self.b_attn = BiAttentionBlock(
            v_dim=v_dim,
            l_dim=l_dim,
            embed_dim=embed_dim,
            num_heads=num_heads,
            dropout=dropout,
            drop_path=drop_path,
            init_values=1.0 / 6.0)

    def forward(self, x):
        """Fuse the visual and language features stored in ``x``.

        Args:
            x (dict): Dict with key ``'visual'`` (list of feature maps) and
                key ``'lang'`` (dict with keys ``'hidden'`` and
                ``'masks'``). The ``'lang'`` dict is updated in place.

        Returns:
            dict: Dict with the fused ``'visual'`` features and the
            ``'lang'`` dict whose ``'hidden'`` entry has been replaced by
            the fused language features.
        """
        visual_features = x['visual']
        language_dict_features = x['lang']

        if self.use_checkpoint:
            fused_visual_features, language_features = checkpoint.checkpoint(
                self.b_attn, visual_features, language_dict_features['hidden'],
                language_dict_features['masks'])
        else:
            fused_visual_features, language_features = self.b_attn(
                visual_features, language_dict_features['hidden'],
                language_dict_features['masks'])

        language_dict_features['hidden'] = language_features
        fused_language_dict_features = language_dict_features

        features_dict = {
            'visual': fused_visual_features,
            'lang': fused_language_dict_features
        }

        return features_dict


class BertEncoderLayer(BertPreTrainedModel):
    """Modified from transformers.models.bert.modeling_bert.BertLayer.

    Args:
        config: BERT configuration object; ``chunk_size_feed_forward`` is
            read from it and it is forwarded to the sub-modules.
        clamp_min_for_underflow (bool): Whether the self-attention clamps
            its scores from below. Defaults to False.
        clamp_max_for_overflow (bool): Whether the self-attention clamps
            its scores from above. Defaults to False.
    """

    def __init__(self,
                 config,
                 clamp_min_for_underflow: bool = False,
                 clamp_max_for_overflow: bool = False):
        if apply_chunking_to_forward is None:
            # The placeholders set by the guarded import cannot build a
            # working layer; fail with an actionable message instead of an
            # opaque ``TypeError`` from ``object.__init__``.
            raise RuntimeError(
                'transformers is not installed, please install it by: '
                'pip install transformers.')
        super().__init__(config)
        self.config = config
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1

        self.attention = BertAttention(config, clamp_min_for_underflow,
                                       clamp_max_for_overflow)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def forward(self, inputs):
        """Apply one BERT layer to the language branch of ``inputs``.

        Args:
            inputs (dict): Dict with key ``'visual'`` and key ``'lang'``
                (dict with keys ``'hidden'`` and ``'masks'``). The
                ``'lang'`` dict is updated in place.

        Returns:
            dict: Dict with the untouched ``'visual'`` entry and the
            ``'lang'`` dict whose ``'hidden'`` entry has been replaced by
            the layer output.
        """
        language_dict_features = inputs['lang']
        hidden_states = language_dict_features['hidden']
        attention_mask = language_dict_features['masks']

        device = hidden_states.device
        input_shape = hidden_states.size()[:-1]
        # We can provide a self-attention mask of dimensions
        # [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it
        # broadcastable to all heads.
        extended_attention_mask = self.get_extended_attention_mask(
            attention_mask, input_shape, device)

        self_attention_outputs = self.attention(
            hidden_states,
            extended_attention_mask,
            None,
            output_attentions=False,
            past_key_value=None,
        )
        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[
            1:]  # add self attentions if we output attention weights
        layer_output = apply_chunking_to_forward(self.feed_forward_chunk,
                                                 self.chunk_size_feed_forward,
                                                 self.seq_len_dim,
                                                 attention_output)
        outputs = (layer_output, ) + outputs
        hidden_states = outputs[0]

        language_dict_features['hidden'] = hidden_states

        features_dict = {
            'visual': inputs['visual'],
            'lang': language_dict_features
        }

        return features_dict

    def feed_forward_chunk(self, attention_output):
        """Run the intermediate and output sub-layers on one chunk."""
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output


# The following code is the same as the Huggingface code,
# with the only difference being the additional clamp operation.
class BertSelfAttention(nn.Module):
    """BERT self-attention layer from Huggingface transformers.

    Compared to the BertSelfAttention of Huggingface, only add the clamp.

    Args:
        config: BERT configuration object providing ``hidden_size``,
            ``num_attention_heads``, ``attention_probs_dropout_prob`` and
            ``is_decoder``; ``position_embedding_type`` is optional and
            defaults to ``'absolute'``.
        clamp_min_for_underflow (bool): Whether to clamp the attention
            scores from below. Defaults to False.
        clamp_max_for_overflow (bool): Whether to clamp the attention
            scores from above. Defaults to False.
    """

    def __init__(self,
                 config,
                 clamp_min_for_underflow: bool = False,
                 clamp_max_for_overflow: bool = False):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and \
                not hasattr(config, 'embedding_size'):
            raise ValueError(f'The hidden size ({config.hidden_size}) is '
                             'not a multiple of the number of attention '
                             f'heads ({config.num_attention_heads})')

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size /
                                       config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * \
            self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = getattr(config,
                                               'position_embedding_type',
                                               'absolute')
        if self.position_embedding_type == 'relative_key' or \
                self.position_embedding_type == 'relative_key_query':
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(
                2 * config.max_position_embeddings - 1,
                self.attention_head_size)
        self.clamp_min_for_underflow = clamp_min_for_underflow
        self.clamp_max_for_overflow = clamp_max_for_overflow

        self.is_decoder = config.is_decoder

    def transpose_for_scores(self, x):
        """Split the last dimension into heads and move heads to dim 1."""
        new_x_shape = x.size()[:-1] + (self.num_attention_heads,
                                       self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        """Compute (optionally clamped) multi-head attention.

        Args:
            hidden_states (Tensor): Input used to compute the queries and,
                for self-attention, the keys and values.
            attention_mask (Tensor, optional): Additive mask that is added
                to the attention scores. Defaults to None.
            head_mask (Tensor, optional): Multiplicative mask applied to
                the attention probabilities. Defaults to None.
            encoder_hidden_states (Tensor, optional): If given, keys and
                values are computed from it (cross-attention).
                Defaults to None.
            encoder_attention_mask (Tensor, optional): Mask used instead of
                ``attention_mask`` for cross-attention. Defaults to None.
            past_key_value (tuple, optional): Cached key and value tensors.
                Defaults to None.
            output_attentions (bool): Whether to also return the attention
                probabilities. Defaults to False.

        Returns:
            tuple: ``(context_layer,)``, followed by the attention
            probabilities if ``output_attentions`` is True, followed by the
            key/value pair if the layer is a decoder.
        """
        mixed_query_layer = self.query(hidden_states)

        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention and past_key_value is not None:
            # reuse k,v, cross_attentions
            key_layer = past_key_value[0]
            value_layer = past_key_value[1]
            attention_mask = encoder_attention_mask
        elif is_cross_attention:
            key_layer = self.transpose_for_scores(
                self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(
                self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        query_layer = self.transpose_for_scores(mixed_query_layer)

        if self.is_decoder:
            past_key_value = (key_layer, value_layer)

        # Take the dot product between "query" and "key"
        # to get the raw attention scores.
        attention_scores = torch.matmul(query_layer,
                                        key_layer.transpose(-1, -2))

        if self.position_embedding_type == 'relative_key' or \
                self.position_embedding_type == 'relative_key_query':
            seq_length = hidden_states.size()[1]
            position_ids_l = torch.arange(
                seq_length, dtype=torch.long,
                device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(
                seq_length, dtype=torch.long,
                device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r
            positional_embedding = self.distance_embedding(
                distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(
                dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == 'relative_key':
                relative_position_scores = torch.einsum(
                    'bhld,lrd->bhlr', query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == 'relative_key_query':
                relative_position_scores_query = torch.einsum(
                    'bhld,lrd->bhlr', query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum(
                    'bhrd,lrd->bhlr', key_layer, positional_embedding)
                attention_scores = attention_scores + \
                    relative_position_scores_query + \
                    relative_position_scores_key

        attention_scores = attention_scores / math.sqrt(
            self.attention_head_size)

        if self.clamp_min_for_underflow:
            attention_scores = torch.clamp(
                attention_scores, min=-MAX_CLAMP_VALUE
            )  # Do not increase -50000, data type half has quite limited range
        if self.clamp_max_for_overflow:
            attention_scores = torch.clamp(
                attention_scores, max=MAX_CLAMP_VALUE
            )  # Do not increase 50000, data type half has quite limited range

        if attention_mask is not None:
            # Apply the attention mask
            # (precomputed for all layers in BertModel forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (
            self.all_head_size, )
        context_layer = context_layer.view(*new_context_layer_shape)

        outputs = (context_layer,
                   attention_probs) if output_attentions else (context_layer, )

        if self.is_decoder:
            outputs = outputs + (past_key_value, )
        return outputs


class BertAttention(HFBertAttention):
    """BertAttention is made up of self-attention and intermediate+output.

    Compared to the BertAttention of Huggingface, only add the clamp.
    """

    def __init__(self,
                 config,
                 clamp_min_for_underflow: bool = False,
                 clamp_max_for_overflow: bool = False):
        super().__init__(config)
        # Replace the Huggingface self-attention with the clamped variant.
        self.self = BertSelfAttention(config, clamp_min_for_underflow,
                                      clamp_max_for_overflow)


class BertIntermediate(HFBertIntermediate):
    """Huggingface BertIntermediate with clamping around the activation."""

    def forward(self, hidden_states):
        """Apply the dense layer and activation, clamping after each."""
        hidden_states = self.dense(hidden_states)
        hidden_states = clamp_values(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        hidden_states = clamp_values(hidden_states)
        return hidden_states


class BertOutput(HFBertOutput):
    """Huggingface BertOutput with clamping around the LayerNorm."""

    def forward(self, hidden_states, input_tensor):
        """Apply dense, dropout and residual LayerNorm with clamping."""
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = clamp_values(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        hidden_states = clamp_values(hidden_states)
        return hidden_states