import torch
from torch import nn
from torch.nn import functional as F

from mmdet.registry import MODELS
from .language_model import LanguageEncoder
from .transformer_blocks import (MLP, Conv2d, CrossAttentionLayer, FFNLayer,
                                 PositionEmbeddingSine, SelfAttentionLayer)
from .utils import is_lower_torch_version


def vl_similarity(image_feat, text_feat, temperature):
    # ``temperature`` is a learnable log-scale tensor (the language encoder's
    # ``logit_scale``), hence the ``.exp()`` below.
    logits = torch.matmul(image_feat, text_feat.t())
    logits = temperature.exp().clamp(max=100) * logits
    return logits


@MODELS.register_module()
class XDecoderTransformerDecoder(nn.Module):

    def __init__(
        self,
        in_channels=512,
        hidden_dim: int = 512,
        dim_proj: int = 512,
        num_queries: int = 101,
        max_token_num: int = 77,
        nheads: int = 8,
        dim_feedforward: int = 2048,
        decoder_layers: int = 9,
        pre_norm: bool = False,
        mask_dim: int = 512,
        task: str = 'semseg',
        captioning_step: int = 50,
    ):
        super().__init__()

        # positional encoding
        self.pe_layer = PositionEmbeddingSine(hidden_dim // 2, normalize=True)

        # define the transformer decoder
        self.num_heads = nheads
        self.num_layers = decoder_layers
        self.max_token_num = max_token_num
        self.transformer_self_attention_layers = nn.ModuleList()
        self.transformer_cross_attention_layers = nn.ModuleList()
        self.transformer_ffn_layers = nn.ModuleList()

        for _ in range(self.num_layers):
            self.transformer_self_attention_layers.append(
                SelfAttentionLayer(
                    d_model=hidden_dim,
                    nhead=nheads,
                    dropout=0.0,
                    normalize_before=pre_norm,
                ))
            self.transformer_cross_attention_layers.append(
                CrossAttentionLayer(
                    d_model=hidden_dim,
                    nhead=nheads,
                    dropout=0.0,
                    normalize_before=pre_norm,
                ))
            self.transformer_ffn_layers.append(
                FFNLayer(
                    d_model=hidden_dim,
                    dim_feedforward=dim_feedforward,
                    dropout=0.0,
                    normalize_before=pre_norm,
                ))

        self.decoder_norm = nn.LayerNorm(hidden_dim)

        self.num_queries = num_queries
        # learnable query features
        self.query_feat = nn.Embedding(num_queries, hidden_dim)
        # learnable query positional embedding
        self.query_embed = nn.Embedding(num_queries, hidden_dim)

        # level embedding (always use 3 scales)
        self.num_feature_levels = 3
        self.level_embed = nn.Embedding(self.num_feature_levels, hidden_dim)
        self.input_proj = nn.ModuleList()
        for _ in range(self.num_feature_levels):
            if in_channels != hidden_dim:
                self.input_proj.append(
                    Conv2d(in_channels, hidden_dim, kernel_size=1))
            else:
                self.input_proj.append(nn.Sequential())

        self.task = task

        # output FFNs
        self.lang_encoder = LanguageEncoder()
        self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3)
        self.class_embed = nn.Parameter(torch.empty(hidden_dim, dim_proj))

        # for caption and ref-caption
        self.caping_embed = nn.Parameter(torch.empty(hidden_dim, dim_proj))
        self.pos_embed_caping = nn.Embedding(max_token_num, hidden_dim)
        self.captioning_step = captioning_step
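
        # Query layout used by the attention masks below: the first
        # ``num_queries - 1`` entries are object queries, entry
        # ``num_queries - 1`` is the class query, and the last
        # ``max_token_num`` entries are caption tokens.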
        # register self_attn_mask to avoid information leakage; it controls
        # the interaction between object queries, the class query and the
        # caption tokens.
        self_attn_mask = torch.zeros((1, num_queries + max_token_num,
                                      num_queries + max_token_num)).bool()
        # object/class queries do not attend to caption tokens.
        self_attn_mask[:, :num_queries, num_queries:] = True
        # caption tokens only attend to previous tokens (causal mask).
        self_attn_mask[:, num_queries:, num_queries:] = torch.triu(
            torch.ones((1, max_token_num, max_token_num)), diagonal=1).bool()
        # object queries do not attend to the class query.
        self_attn_mask[:, :num_queries - 1, num_queries - 1:num_queries] = True
        # the class query does not attend to object queries.
        self_attn_mask[:, num_queries - 1:num_queries, :num_queries - 1] = True
        self.register_buffer('self_attn_mask', self_attn_mask)

    def forward(self, x, mask_features, extra=None):
        if self.task == 'caption':
            return self.forward_caption(x, mask_features, extra)

        assert len(x) == self.num_feature_levels
        src = []
        pos = []
        size_list = []
        for i in range(self.num_feature_levels):
            size_list.append(x[i].shape[-2:])
            pos.append(self.pe_layer(x[i], None).flatten(2))
            src.append(self.input_proj[i](x[i]).flatten(2) +
                       self.level_embed.weight[i][None, :, None])

            # flatten NxCxHxW to HWxNxC
            pos[-1] = pos[-1].permute(2, 0, 1)
            src[-1] = src[-1].permute(2, 0, 1)

        _, bs, _ = src[0].shape

        query_embed = self.query_embed.weight.unsqueeze(1).repeat(1, bs, 1)
        output = self.query_feat.weight.unsqueeze(1).repeat(1, bs, 1)

        predictions_mask = []
        predictions_class_embed = []

        if self.task == 'ref-seg':
            self_tgt_mask = self.self_attn_mask[
                :, :self.num_queries, :self.num_queries].repeat(
                    output.shape[1] * self.num_heads, 1, 1)
            grounding_tokens = extra['grounding_tokens']
            _grounding_tokens = grounding_tokens.detach().clone()

            # initialize with a fully blocked (negative) attention mask.
            pad_tgt_mask = torch.ones(
                (1, self.num_queries + (self.num_queries - 1) +
                 len(grounding_tokens), self.num_queries +
                 (self.num_queries - 1) + len(grounding_tokens)),
                device=self_tgt_mask.device).bool().repeat(
                    output.shape[1] * self.num_heads, 1, 1)
            pad_tgt_mask[
                :, :self.num_queries, :self.num_queries] = self_tgt_mask
            # grounding tokens can attend to each other.
            pad_tgt_mask[:, self.num_queries:, self.num_queries:] = False
            self_tgt_mask = pad_tgt_mask
            output = torch.cat((output, output[:-1]), dim=0)
            # also pad the query positional embedding to match the padded
            # queries.
            query_embed = torch.cat((query_embed, query_embed[:-1]), dim=0)
        else:
            self_tgt_mask = self.self_attn_mask[
                :, :self.num_queries, :self.num_queries].repeat(
                    output.shape[1] * self.num_heads, 1, 1)
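
        # Initial prediction from the raw learnable queries; each decoder
        # layer below refines it and supplies the attention mask for the
        # next feature scale.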
        results = self.forward_prediction_heads(
            output, mask_features, attn_mask_target_size=size_list[0])
        attn_mask = results['attn_mask']
        predictions_class_embed.append(results['class_embed'])
        predictions_mask.append(results['outputs_mask'])

        for i in range(self.num_layers):
            level_index = i % self.num_feature_levels
            attn_mask[torch.where(
                attn_mask.sum(-1) == attn_mask.shape[-1])] = False

            # attention: cross-attention first
            output, avg_attn = self.transformer_cross_attention_layers[i](
                output,
                src[level_index],
                memory_mask=attn_mask,
                # here we do not apply masking on padded region
                memory_key_padding_mask=None,
                pos=pos[level_index],
                query_pos=query_embed)

            if self.task == 'ref-seg':
                output = torch.cat((output, _grounding_tokens), dim=0)
                query_embed = torch.cat((query_embed, grounding_tokens),
                                        dim=0)

            output = self.transformer_self_attention_layers[i](
                output,
                tgt_mask=self_tgt_mask,
                tgt_key_padding_mask=None,
                query_pos=query_embed)
            output = self.transformer_ffn_layers[i](output)

            if self.task == 'ref-seg':
                _grounding_tokens = output[-len(_grounding_tokens):]
                output = output[:-len(_grounding_tokens)]
                query_embed = query_embed[:-len(_grounding_tokens)]

            results = self.forward_prediction_heads(
                output,
                mask_features,
                attn_mask_target_size=size_list[(i + 1) %
                                                self.num_feature_levels])
            attn_mask = results['attn_mask']
            predictions_mask.append(results['outputs_mask'])
            predictions_class_embed.append(results['class_embed'])

        out = {
            'pred_masks': predictions_mask[-1],
            'pred_class_embed': predictions_class_embed[-1],
        }

        if self.task == 'ref-seg':
            mask_pred_results = []
            outputs_class = []
            for idx in range(mask_features.shape[0]):  # batch size
                pred_gmasks = out['pred_masks'][
                    idx, self.num_queries:2 * self.num_queries - 1]
                v_emb = predictions_class_embed[-1][
                    idx, self.num_queries:2 * self.num_queries - 1]
                t_emb = extra['class_emb']

                t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
                v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)

                temperature = self.lang_encoder.logit_scale
                out_prob = vl_similarity(
                    v_emb, t_emb, temperature=temperature)
                matched_id = out_prob.max(0)[1]
                mask_pred_results += [pred_gmasks[matched_id, :, :]]
                outputs_class += [out_prob[matched_id, :]]
            out['pred_masks'] = mask_pred_results
            out['pred_logits'] = outputs_class
        elif self.task == 'retrieval':
            t_emb = extra['class_emb']
            temperature = self.lang_encoder.logit_scale
            v_emb = out['pred_class_embed'][:, -1, :]
            v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
            logits = vl_similarity(v_emb, t_emb, temperature)
            out['pred_logits'] = logits
        elif self.task in ['semseg', 'instance', 'panoptic']:
            outputs_class = self.lang_encoder.compute_similarity(
                out['pred_class_embed'])
            out['pred_logits'] = outputs_class
        return out

    def forward_caption(self, x, mask_features, extra=None):
        assert len(x) == self.num_feature_levels
        src = []
        pos = []
        size_list = []
        for i in range(self.num_feature_levels):
            size_list.append(x[i].shape[-2:])
            pos.append(self.pe_layer(x[i], None).flatten(2))
            src.append(self.input_proj[i](x[i]).flatten(2) +
                       self.level_embed.weight[i][None, :, None])

            # flatten NxCxHxW to HWxNxC
            pos[-1] = pos[-1].permute(2, 0, 1)
            src[-1] = src[-1].permute(2, 0, 1)

        _, bs, _ = src[0].shape

        # QxNxC
        query_embed_ = self.query_embed.weight.unsqueeze(1).repeat(1, bs, 1)
        query_feat = self.query_feat.weight.unsqueeze(1).repeat(1, bs, 1)
        lang_token = extra['start_token'].repeat(bs, 1)
        pos_embed = self.pos_embed_caping.weight.unsqueeze(1).repeat(1, bs, 1)
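
        # The loop below performs greedy autoregressive decoding: at each
        # step the full decoder is re-run on the caption generated so far and
        # the most likely next token is appended to ``lang_token``.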
        # prepare token embedding for evaluation
        token_embs = self.lang_encoder.lang_encoder.token_embedding.weight

        for cap_idx in range(0, self.captioning_step):
            lang_embed = self.lang_encoder.forward_language(
                (lang_token, ), with_cls_embed=False)[1].transpose(0, 1)
            # concat object query, class token and caption token.
            output = torch.cat((query_feat, lang_embed), dim=0)
            lang_embed += pos_embed
            query_embed = torch.cat((query_embed_, lang_embed), dim=0)

            # prediction heads on learnable query features
            results = self.forward_prediction_heads(
                output, mask_features, attn_mask_target_size=size_list[0])
            attn_mask = results['attn_mask']

            for i in range(self.num_layers):
                level_index = i % self.num_feature_levels
                attn_mask[torch.where(
                    attn_mask.sum(-1) == attn_mask.shape[-1])] = False
                attn_mask = torch.cat(
                    (attn_mask,
                     torch.zeros_like(attn_mask[:, :self.max_token_num, :])),
                    dim=1)
                self_tgt_mask = self.self_attn_mask.repeat(
                    output.shape[1] * self.num_heads, 1, 1)

                if 'grounding_mask' in extra:
                    bs, nq, wh = attn_mask.shape
                    assert bs == self.num_heads, \
                        'Only support single image referring captioning.'
                    grounding_mask = extra['grounding_mask']
                    attn_mask = attn_mask.reshape(
                        bs, nq, size_list[i % 3][0], size_list[i % 3][1])
                    grounding_mask = F.interpolate(
                        grounding_mask.float(),
                        size_list[i % 3],
                        mode='nearest').bool()[0, 0]
                    attn_mask[:, self.num_queries:, grounding_mask] = True
                    attn_mask = attn_mask.reshape(bs, nq, wh)

                # attention: cross-attention first
                output, avg_attn = self.transformer_cross_attention_layers[i](
                    output,
                    src[level_index],
                    memory_mask=attn_mask,
                    # here we do not apply masking on padded region
                    memory_key_padding_mask=None,
                    pos=pos[level_index],
                    query_pos=query_embed)

                output = self.transformer_self_attention_layers[i](
                    output,
                    tgt_mask=self_tgt_mask,
                    tgt_key_padding_mask=None,
                    query_pos=query_embed)
                output = self.transformer_ffn_layers[i](output)

                results = self.forward_prediction_heads(
                    output,
                    mask_features,
                    attn_mask_target_size=size_list[(i + 1) %
                                                    self.num_feature_levels])
                attn_mask = results['attn_mask']

            pred_captions = results['outputs_caption']
            pred_captions = pred_captions @ token_embs.t()
            lang_token[:, cap_idx + 1] = pred_captions[:, cap_idx].max(-1)[1]

        texts = self.lang_encoder.tokenizer.batch_decode(
            lang_token, skip_special_tokens=False)
        texts_new = []
        for text in texts:
            text = text.split('<|endoftext|>')[0]
            text = text.replace('<|endoftext|>', '')
            text = text.replace('<|startoftext|>', '')
            text = text.strip()
            texts_new.append(text)

        out = {'pred_caption': texts_new}
        return out

    def forward_prediction_heads(self, output, mask_features,
                                 attn_mask_target_size):
        decoder_output = self.decoder_norm(output)
        decoder_output = decoder_output.transpose(0, 1)

        if self.task == 'caption':
            outputs_caption = decoder_output[
                :, self.num_queries:] @ self.caping_embed

        # recompute class token output.
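        # The class token becomes a similarity-weighted average of the object
        # queries, summarising the whole image for image-text matching.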
        norm_decoder_output = decoder_output / (
            decoder_output.norm(dim=-1, keepdim=True) + 1e-7)
        obj_token = norm_decoder_output[:, :self.num_queries - 1]
        cls_token = norm_decoder_output[
            :, self.num_queries - 1:self.num_queries]
        sim = (cls_token @ obj_token.transpose(1, 2)).softmax(-1)[
            :, 0, :, None]
        cls_token = (sim * decoder_output[:, :self.num_queries - 1]).sum(
            dim=1, keepdim=True)

        if self.task == 'ref-seg':
            decoder_output = torch.cat(
                (decoder_output[:, :self.num_queries - 1], cls_token,
                 decoder_output[:,
                                self.num_queries:2 * self.num_queries - 1]),
                dim=1)
        else:
            decoder_output = torch.cat(
                (decoder_output[:, :self.num_queries - 1], cls_token), dim=1)

        mask_embed = self.mask_embed(decoder_output)
        outputs_mask = torch.einsum('bqc,bchw->bqhw', mask_embed,
                                    mask_features)

        if is_lower_torch_version():
            attn_mask = F.interpolate(
                outputs_mask,
                size=attn_mask_target_size,
                mode='bicubic',
                align_corners=False)
        else:
            attn_mask = F.interpolate(
                outputs_mask,
                size=attn_mask_target_size,
                mode='bicubic',
                align_corners=False,
                antialias=True)

        attn_mask = (attn_mask.sigmoid().flatten(2).unsqueeze(1).repeat(
            1, self.num_heads, 1, 1).flatten(0, 1) < 0.5).bool()
        attn_mask = attn_mask.detach()
        attn_mask[:, self.num_queries:self.num_queries + 1].fill_(False)

        if self.task == 'caption':
            results = {
                'attn_mask': attn_mask,
                'outputs_caption': outputs_caption,
            }
            return results
        else:
            class_embed = decoder_output @ self.class_embed
            results = {
                'outputs_mask': outputs_mask,
                'attn_mask': attn_mask,
                'class_embed': class_embed,
            }
            return results
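
# A minimal usage sketch for the 'semseg' task, assuming the pixel decoder
# provides three 512-channel feature maps plus per-pixel mask features, and
# that pretrained X-Decoder weights (including the language encoder with text
# embeddings prepared for the target vocabulary) have been loaded beforehand;
# the tensor shapes below are illustrative:
#
#   decoder = XDecoderTransformerDecoder(task='semseg')
#   feats = [torch.randn(1, 512, s, s) for s in (32, 64, 128)]
#   mask_features = torch.randn(1, 512, 256, 256)
#   out = decoder(feats, mask_features)
#   # out['pred_masks']:  (1, num_queries, 256, 256) mask logits
#   # out['pred_logits']: per-query similarity to the text embeddings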