import torch
from torch import nn
from torch.nn import functional as F

from mmdet.registry import MODELS

from .language_model import LanguageEncoder
from .transformer_blocks import (MLP, Conv2d, CrossAttentionLayer, FFNLayer,
                                 PositionEmbeddingSine, SelfAttentionLayer)
from .utils import is_lower_torch_version


def vl_similarity(image_feat, text_feat, temperature=1):
    logits = torch.matmul(image_feat, text_feat.t())
    logits = temperature.exp().clamp(max=100) * logits
    return logits
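
# Illustrative use of ``vl_similarity`` (a minimal sketch, not part of the
# original module): ``temperature`` is expected to be a tensor such as the
# language encoder's learnable ``logit_scale``; the integer default of 1
# would fail the ``.exp()`` call above if passed as-is.
#
#   image_feat = F.normalize(torch.randn(4, 512), dim=-1)
#   text_feat = F.normalize(torch.randn(10, 512), dim=-1)
#   logit_scale = nn.Parameter(torch.ones([]) * 4.6052)  # log(100)
#   logits = vl_similarity(image_feat, text_feat, temperature=logit_scale)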


@MODELS.register_module()
class XDecoderTransformerDecoder(nn.Module):
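    """Transformer decoder of X-Decoder.

    One set of learnable queries is shared across tasks: the first
    ``num_queries - 1`` entries act as object queries, the last entry acts as
    a class/global query, and for captioning up to ``max_token_num`` language
    tokens are appended after the object queries.

    Args:
        in_channels (int): Channels of the multi-scale pixel-decoder features.
        hidden_dim (int): Embedding dimension of the decoder.
        dim_proj (int): Dimension of the projection used for vision-language
            similarity.
        num_queries (int): Number of learnable queries (object queries plus
            one class query).
        max_token_num (int): Maximum number of language tokens.
        nheads (int): Number of attention heads.
        dim_feedforward (int): Hidden dimension of the FFN layers.
        decoder_layers (int): Number of decoder layers.
        pre_norm (bool): Whether to normalize before attention/FFN.
        mask_dim (int): Channel dimension of the mask embeddings.
        task (str): One of 'semseg', 'instance', 'panoptic', 'ref-seg',
            'retrieval' or 'caption'.
        captioning_step (int): Number of autoregressive decoding steps used
            for captioning.
    """
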
    def __init__(
        self,
        in_channels=512,
        hidden_dim: int = 512,
        dim_proj: int = 512,
        num_queries: int = 101,
        max_token_num: int = 77,
        nheads: int = 8,
        dim_feedforward: int = 2048,
        decoder_layers: int = 9,
        pre_norm: bool = False,
        mask_dim: int = 512,
        task: str = 'semseg',
        captioning_step: int = 50,
    ):
        super().__init__()

        # positional encoding
        self.pe_layer = PositionEmbeddingSine(hidden_dim // 2, normalize=True)

        # define transformer decoder here
        self.num_heads = nheads
        self.num_layers = decoder_layers
        self.max_token_num = max_token_num
        self.transformer_self_attention_layers = nn.ModuleList()
        self.transformer_cross_attention_layers = nn.ModuleList()
        self.transformer_ffn_layers = nn.ModuleList()

        for _ in range(self.num_layers):
            self.transformer_self_attention_layers.append(
                SelfAttentionLayer(
                    d_model=hidden_dim,
                    nhead=nheads,
                    dropout=0.0,
                    normalize_before=pre_norm,
                ))
            self.transformer_cross_attention_layers.append(
                CrossAttentionLayer(
                    d_model=hidden_dim,
                    nhead=nheads,
                    dropout=0.0,
                    normalize_before=pre_norm,
                ))
            self.transformer_ffn_layers.append(
                FFNLayer(
                    d_model=hidden_dim,
                    dim_feedforward=dim_feedforward,
                    dropout=0.0,
                    normalize_before=pre_norm,
                ))

        self.decoder_norm = nn.LayerNorm(hidden_dim)

        self.num_queries = num_queries
        # learnable query features
        self.query_feat = nn.Embedding(num_queries, hidden_dim)
        # learnable query p.e.
        self.query_embed = nn.Embedding(num_queries, hidden_dim)

        # level embedding (always use 3 scales)
        self.num_feature_levels = 3
        self.level_embed = nn.Embedding(self.num_feature_levels, hidden_dim)
        self.input_proj = nn.ModuleList()
        for _ in range(self.num_feature_levels):
            if in_channels != hidden_dim:
                self.input_proj.append(
                    Conv2d(in_channels, hidden_dim, kernel_size=1))
            else:
                self.input_proj.append(nn.Sequential())

        self.task = task

        # output FFNs
        self.lang_encoder = LanguageEncoder()
        self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3)
        self.class_embed = nn.Parameter(torch.empty(hidden_dim, dim_proj))

        # for caption and ref-caption
        self.caping_embed = nn.Parameter(torch.empty(hidden_dim, dim_proj))
        self.pos_embed_caping = nn.Embedding(max_token_num, hidden_dim)
        self.captioning_step = captioning_step

        # Register self_attn_mask to avoid information leakage. It covers the
        # interaction between object queries, the class query and the caption
        # queries.
        self_attn_mask = torch.zeros((1, num_queries + max_token_num,
                                      num_queries + max_token_num)).bool()
        # Object and class queries do not attend to caption queries.
        self_attn_mask[:, :num_queries, num_queries:] = True
        # Caption queries only attend to previous tokens (causal mask).
        self_attn_mask[:, num_queries:, num_queries:] = torch.triu(
            torch.ones((1, max_token_num, max_token_num)), diagonal=1).bool()
        # Object queries do not attend to the class query.
        self_attn_mask[:, :num_queries - 1, num_queries - 1:num_queries] = True
        # The class query does not attend to object queries.
        self_attn_mask[:, num_queries - 1:num_queries, :num_queries - 1] = True
        self.register_buffer('self_attn_mask', self_attn_mask)

    def forward(self, x, mask_features, extra=None):
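        """Run the decoder for all tasks except captioning.

        Args:
            x (list[Tensor]): Multi-scale features, one per feature level,
                each of shape (N, C, H, W).
            mask_features (Tensor): Per-pixel embeddings of shape
                (N, C, H, W) used to produce the mask logits.
            extra (dict, optional): Task-specific inputs, e.g.
                ``grounding_tokens`` for 'ref-seg' and ``class_emb`` for
                'ref-seg' / 'retrieval'.

        Returns:
            dict: Predicted masks, class embeddings and, depending on the
            task, logits.
        """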
        if self.task == 'caption':
            return self.forward_caption(x, mask_features, extra)

        assert len(x) == self.num_feature_levels
        src = []
        pos = []
        size_list = []

        for i in range(self.num_feature_levels):
            size_list.append(x[i].shape[-2:])
            pos.append(self.pe_layer(x[i], None).flatten(2))
            src.append(self.input_proj[i](x[i]).flatten(2) +
                       self.level_embed.weight[i][None, :, None])

            # flatten NxCxHxW to HWxNxC
            pos[-1] = pos[-1].permute(2, 0, 1)
            src[-1] = src[-1].permute(2, 0, 1)

        _, bs, _ = src[0].shape

        query_embed = self.query_embed.weight.unsqueeze(1).repeat(1, bs, 1)
        output = self.query_feat.weight.unsqueeze(1).repeat(1, bs, 1)

        predictions_mask = []
        predictions_class_embed = []

        if self.task == 'ref-seg':
            self_tgt_mask = self.self_attn_mask[:, :self.num_queries, :self.
                                                num_queries].repeat(
                                                    output.shape[1] *
                                                    self.num_heads, 1, 1)
            grounding_tokens = extra['grounding_tokens']
            _grounding_tokens = grounding_tokens.detach().clone()
            # initialize with negative attention at the beginning.
            pad_tgt_mask = torch.ones(
                (1, self.num_queries + (self.num_queries - 1) +
                 len(grounding_tokens), self.num_queries +
                 (self.num_queries - 1) + len(grounding_tokens)),
                device=self_tgt_mask.device).bool().repeat(
                    output.shape[1] * self.num_heads, 1, 1)
            pad_tgt_mask[:, :self.num_queries, :self.
                         num_queries] = self_tgt_mask
            # grounding tokens can attend to each other
            pad_tgt_mask[:, self.num_queries:, self.num_queries:] = False
            self_tgt_mask = pad_tgt_mask
            output = torch.cat((output, output[:-1]), dim=0)
            # also pad the query p.e. to match the padded queries
            query_embed = torch.cat((query_embed, query_embed[:-1]), dim=0)
        else:
            self_tgt_mask = self.self_attn_mask[:, :self.num_queries, :self.
                                                num_queries].repeat(
                                                    output.shape[1] *
                                                    self.num_heads, 1, 1)

        results = self.forward_prediction_heads(
            output, mask_features, attn_mask_target_size=size_list[0])
        attn_mask = results['attn_mask']
        predictions_class_embed.append(results['class_embed'])
        predictions_mask.append(results['outputs_mask'])

        for i in range(self.num_layers):
            level_index = i % self.num_feature_levels
            attn_mask[torch.where(
                attn_mask.sum(-1) == attn_mask.shape[-1])] = False

            # attention: cross-attention first
            output, avg_attn = self.transformer_cross_attention_layers[i](
                output,
                src[level_index],
                memory_mask=attn_mask,
                # here we do not apply masking on padded region
                memory_key_padding_mask=None,
                pos=pos[level_index],
                query_pos=query_embed)

            if self.task == 'ref-seg':
                output = torch.cat((output, _grounding_tokens), dim=0)
                query_embed = torch.cat((query_embed, grounding_tokens),
                                        dim=0)

            output = self.transformer_self_attention_layers[i](
                output,
                tgt_mask=self_tgt_mask,
                tgt_key_padding_mask=None,
                query_pos=query_embed)
            output = self.transformer_ffn_layers[i](output)

            if self.task == 'ref-seg':
                _grounding_tokens = output[-len(_grounding_tokens):]
                output = output[:-len(_grounding_tokens)]
                query_embed = query_embed[:-len(_grounding_tokens)]

            results = self.forward_prediction_heads(
                output,
                mask_features,
                attn_mask_target_size=size_list[(i + 1) %
                                                self.num_feature_levels])
            attn_mask = results['attn_mask']
            predictions_mask.append(results['outputs_mask'])
            predictions_class_embed.append(results['class_embed'])

        out = {
            'pred_masks': predictions_mask[-1],
            'pred_class_embed': predictions_class_embed[-1],
        }

        if self.task == 'ref-seg':
            mask_pred_results = []
            outputs_class = []
            for idx in range(mask_features.shape[0]):  # batch size
                pred_gmasks = out['pred_masks'][idx, self.num_queries:2 *
                                                self.num_queries - 1]
                v_emb = predictions_class_embed[-1][idx, self.num_queries:2 *
                                                    self.num_queries - 1]
                t_emb = extra['class_emb']

                t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
                v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)

                temperature = self.lang_encoder.logit_scale
                out_prob = vl_similarity(
                    v_emb, t_emb, temperature=temperature)
                matched_id = out_prob.max(0)[1]
                mask_pred_results += [pred_gmasks[matched_id, :, :]]
                outputs_class += [out_prob[matched_id, :]]
            out['pred_masks'] = mask_pred_results
            out['pred_logits'] = outputs_class
        elif self.task == 'retrieval':
            t_emb = extra['class_emb']
            temperature = self.lang_encoder.logit_scale
            v_emb = out['pred_class_embed'][:, -1, :]
            v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
            logits = vl_similarity(v_emb, t_emb, temperature)
            out['pred_logits'] = logits
        elif self.task in ['semseg', 'instance', 'panoptic']:
            outputs_class = self.lang_encoder.compute_similarity(
                out['pred_class_embed'])
            out['pred_logits'] = outputs_class
        return out

    def forward_caption(self, x, mask_features, extra=None):
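        """Autoregressively decode a caption.

        Starting from ``extra['start_token']``, the decoder is run
        ``captioning_step`` times, each time predicting the next language
        token from the caption queries. If ``extra['grounding_mask']`` is
        given, cross-attention of the caption queries is restricted to the
        grounded region (referring captioning).
        """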
        assert len(x) == self.num_feature_levels
        src = []
        pos = []
        size_list = []

        for i in range(self.num_feature_levels):
            size_list.append(x[i].shape[-2:])
            pos.append(self.pe_layer(x[i], None).flatten(2))
            src.append(self.input_proj[i](x[i]).flatten(2) +
                       self.level_embed.weight[i][None, :, None])

            # flatten NxCxHxW to HWxNxC
            pos[-1] = pos[-1].permute(2, 0, 1)
            src[-1] = src[-1].permute(2, 0, 1)

        _, bs, _ = src[0].shape

        # QxNxC
        query_embed_ = self.query_embed.weight.unsqueeze(1).repeat(1, bs, 1)
        query_feat = self.query_feat.weight.unsqueeze(1).repeat(1, bs, 1)
        lang_token = extra['start_token'].repeat(bs, 1)
        pos_embed = self.pos_embed_caping.weight.unsqueeze(1).repeat(1, bs, 1)

        # prepare token embedding for evaluation
        token_embs = self.lang_encoder.lang_encoder.token_embedding.weight

        for cap_idx in range(0, self.captioning_step):
            lang_embed = self.lang_encoder.forward_language(
                (lang_token, ), with_cls_embed=False)[1].transpose(0, 1)
            # concat object query, class token and caption token.
            output = torch.cat((query_feat, lang_embed), dim=0)

            lang_embed += pos_embed
            query_embed = torch.cat((query_embed_, lang_embed), dim=0)

            # prediction heads on learnable query features
            results = self.forward_prediction_heads(
                output, mask_features, attn_mask_target_size=size_list[0])
            attn_mask = results['attn_mask']

            for i in range(self.num_layers):
                level_index = i % self.num_feature_levels
                attn_mask[torch.where(
                    attn_mask.sum(-1) == attn_mask.shape[-1])] = False

                attn_mask = torch.cat(
                    (attn_mask,
                     torch.zeros_like(attn_mask[:, :self.max_token_num, :])),
                    dim=1)
                self_tgt_mask = self.self_attn_mask.repeat(
                    output.shape[1] * self.num_heads, 1, 1)

                if 'grounding_mask' in extra:
                    bs, nq, wh = attn_mask.shape
                    assert bs == self.num_heads, 'Only support single ' \
                        'image referring captioning.'
                    grounding_mask = extra['grounding_mask']
                    attn_mask = attn_mask.reshape(bs, nq, size_list[i % 3][0],
                                                  size_list[i % 3][1])
                    grounding_mask = F.interpolate(
                        grounding_mask.float(),
                        size_list[i % 3],
                        mode='nearest').bool()[0, 0]
                    attn_mask[:, self.num_queries:, grounding_mask] = True
                    attn_mask = attn_mask.reshape(bs, nq, wh)

                # attention: cross-attention first
                output, avg_attn = self.transformer_cross_attention_layers[i](
                    output,
                    src[level_index],
                    memory_mask=attn_mask,
                    # here we do not apply masking on padded region
                    memory_key_padding_mask=None,
                    pos=pos[level_index],
                    query_pos=query_embed)

                output = self.transformer_self_attention_layers[i](
                    output,
                    tgt_mask=self_tgt_mask,
                    tgt_key_padding_mask=None,
                    query_pos=query_embed)
                output = self.transformer_ffn_layers[i](output)

                results = self.forward_prediction_heads(
                    output,
                    mask_features,
                    attn_mask_target_size=size_list[(i + 1) %
                                                    self.num_feature_levels])
                attn_mask = results['attn_mask']

            pred_captions = results['outputs_caption']
            pred_captions = pred_captions @ token_embs.t()
            lang_token[:, cap_idx + 1] = pred_captions[:, cap_idx].max(-1)[1]

        texts = self.lang_encoder.tokenizer.batch_decode(
            lang_token, skip_special_tokens=False)
        texts_new = []
        for x in texts:
            x = x.split('<|endoftext|>')[0]
            x = x.replace('<|endoftext|>', '')
            x = x.replace('<|startoftext|>', '')
            x = x.strip()
            texts_new.append(x)

        out = {'pred_caption': texts_new}
        return out

    def forward_prediction_heads(self, output, mask_features,
                                 attn_mask_target_size):
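        """Produce per-layer predictions from the current decoder state.

        Returns the mask logits, the attention mask used by the next
        cross-attention layer and, depending on the task, either caption
        outputs or class embeddings.
        """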
        decoder_output = self.decoder_norm(output)
        decoder_output = decoder_output.transpose(0, 1)

        if self.task == 'caption':
            outputs_caption = decoder_output[
                :, self.num_queries:] @ self.caping_embed

        # recompute class token output.
        norm_decoder_output = decoder_output / (
            decoder_output.norm(dim=-1, keepdim=True) + 1e-7)
        obj_token = norm_decoder_output[:, :self.num_queries - 1]
        cls_token = norm_decoder_output[:,
                                        self.num_queries - 1:self.num_queries]

        sim = (cls_token @ obj_token.transpose(1, 2)).softmax(-1)[:, 0, :,
                                                                  None]
        cls_token = (sim * decoder_output[:, :self.num_queries - 1]).sum(
            dim=1, keepdim=True)

        if self.task == 'ref-seg':
            decoder_output = torch.cat(
                (decoder_output[:, :self.num_queries - 1], cls_token,
                 decoder_output[:, self.num_queries:2 * self.num_queries - 1]),
                dim=1)
        else:
            decoder_output = torch.cat(
                (decoder_output[:, :self.num_queries - 1], cls_token), dim=1)

        mask_embed = self.mask_embed(decoder_output)
        outputs_mask = torch.einsum('bqc,bchw->bqhw', mask_embed,
                                    mask_features)

        if is_lower_torch_version():
            attn_mask = F.interpolate(
                outputs_mask,
                size=attn_mask_target_size,
                mode='bicubic',
                align_corners=False)
        else:
            attn_mask = F.interpolate(
                outputs_mask,
                size=attn_mask_target_size,
                mode='bicubic',
                align_corners=False,
                antialias=True)

        attn_mask = (attn_mask.sigmoid().flatten(2).unsqueeze(1).repeat(
            1, self.num_heads, 1, 1).flatten(0, 1) < 0.5).bool()
        attn_mask = attn_mask.detach()
        attn_mask[:, self.num_queries:self.num_queries + 1].fill_(False)

        if self.task == 'caption':
            results = {
                'attn_mask': attn_mask,
                'outputs_caption': outputs_caption,
            }
            return results
        else:
            class_embed = decoder_output @ self.class_embed

            results = {
                'outputs_mask': outputs_mask,
                'attn_mask': attn_mask,
                'class_embed': class_embed,
            }
            return results
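

# Usage sketch (illustrative only, not part of the original file). It assumes
# three 512-channel pixel-decoder feature maps and matching mask_features,
# and that the class/text embeddings have been prepared via
# ``decoder.lang_encoder`` beforehand; all shapes below are hypothetical.
#
#   decoder = XDecoderTransformerDecoder(task='semseg')
#   feats = [torch.randn(1, 512, s, s) for s in (16, 32, 64)]
#   mask_features = torch.randn(1, 512, 128, 128)
#   out = decoder(feats, mask_features)
#   # out['pred_masks']  -> (1, num_queries, 128, 128)
#   # out['pred_logits'] -> similarity of the class embeddings against the
#   #                       prepared text embeddings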