import torch
from torch import nn
from torch.nn import functional as F

from mmdet.registry import MODELS
from .language_model import LanguageEncoder
from .transformer_blocks import (MLP, Conv2d, CrossAttentionLayer, FFNLayer,
                                 PositionEmbeddingSine, SelfAttentionLayer)
from .utils import is_lower_torch_version


def vl_similarity(image_feat, text_feat, temperature=1):
    logits = torch.matmul(image_feat, text_feat.t())
    # ``temperature`` is expected to be a tensor (e.g. a learnable logit
    # scale); the default integer would fail at ``.exp()``.
    logits = temperature.exp().clamp(max=100) * logits
    return logits
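

# Usage sketch (illustrative, not part of the original call sites): both
# embedding sets are L2-normalized before the call and the temperature comes
# from the language encoder's learnable logit scale, e.g.
#     v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
#     t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
#     logits = vl_similarity(v_emb, t_emb,
#                            temperature=lang_encoder.logit_scale)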


@MODELS.register_module()
class XDecoderTransformerDecoder(nn.Module):
    """Transformer decoder of X-Decoder.

    Depending on ``task``, the same decoder serves semantic/instance/panoptic
    segmentation, referring segmentation, image-text retrieval and image
    captioning.
    """

    def __init__(
        self,
        in_channels: int = 512,
        hidden_dim: int = 512,
        dim_proj: int = 512,
        num_queries: int = 101,
        max_token_num: int = 77,
        nheads: int = 8,
        dim_feedforward: int = 2048,
        decoder_layers: int = 9,
        pre_norm: bool = False,
        mask_dim: int = 512,
        task: str = 'semseg',
        captioning_step: int = 50,
    ):
        super().__init__()

        # positional encoding
        self.pe_layer = PositionEmbeddingSine(hidden_dim // 2, normalize=True)

        # define transformer decoder here
        self.num_heads = nheads
        self.num_layers = decoder_layers
        self.max_token_num = max_token_num
        self.transformer_self_attention_layers = nn.ModuleList()
        self.transformer_cross_attention_layers = nn.ModuleList()
        self.transformer_ffn_layers = nn.ModuleList()

        # each decoder layer consists of cross-attention, self-attention and
        # an FFN (applied in that order in ``forward``).
        for _ in range(self.num_layers):
            self.transformer_self_attention_layers.append(
                SelfAttentionLayer(
                    d_model=hidden_dim,
                    nhead=nheads,
                    dropout=0.0,
                    normalize_before=pre_norm,
                ))
            self.transformer_cross_attention_layers.append(
                CrossAttentionLayer(
                    d_model=hidden_dim,
                    nhead=nheads,
                    dropout=0.0,
                    normalize_before=pre_norm,
                ))
            self.transformer_ffn_layers.append(
                FFNLayer(
                    d_model=hidden_dim,
                    dim_feedforward=dim_feedforward,
                    dropout=0.0,
                    normalize_before=pre_norm,
                ))

        self.decoder_norm = nn.LayerNorm(hidden_dim)
        self.num_queries = num_queries
        # learnable query features
        self.query_feat = nn.Embedding(num_queries, hidden_dim)
        # learnable query p.e.
        self.query_embed = nn.Embedding(num_queries, hidden_dim)

        # level embedding (always use 3 scales)
        self.num_feature_levels = 3
        self.level_embed = nn.Embedding(self.num_feature_levels, hidden_dim)
        self.input_proj = nn.ModuleList()
        for _ in range(self.num_feature_levels):
            if in_channels != hidden_dim:
                self.input_proj.append(
                    Conv2d(in_channels, hidden_dim, kernel_size=1))
            else:
                self.input_proj.append(nn.Sequential())

        self.task = task

        # output FFNs
        self.lang_encoder = LanguageEncoder()
        self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3)
        self.class_embed = nn.Parameter(torch.empty(hidden_dim, dim_proj))

        # for caption and ref-caption
        self.caping_embed = nn.Parameter(torch.empty(hidden_dim, dim_proj))
        self.pos_embed_caping = nn.Embedding(max_token_num, hidden_dim)
        self.captioning_step = captioning_step

        # register self_attn_mask to avoid information leakage; it covers the
        # interaction between object queries, the class query and the caption
        # queries.
        self_attn_mask = torch.zeros((1, num_queries + max_token_num,
                                      num_queries + max_token_num)).bool()
        # object and class queries do not attend to caption queries.
        self_attn_mask[:, :num_queries, num_queries:] = True
        # caption queries only attend to previous caption tokens.
        self_attn_mask[:, num_queries:, num_queries:] = torch.triu(
            torch.ones((1, max_token_num, max_token_num)), diagonal=1).bool()
        # object queries do not attend to the class query.
        self_attn_mask[:, :num_queries - 1, num_queries - 1:num_queries] = True
        # the class query does not attend to object queries.
        self_attn_mask[:, num_queries - 1:num_queries, :num_queries - 1] = True
        self.register_buffer('self_attn_mask', self_attn_mask)
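
    # Layout of ``self_attn_mask`` (True = masked out). Illustrative sketch
    # assuming the default num_queries=101 and max_token_num=77, so object
    # queries sit at indices 0..99, the class query at 100 and the caption
    # queries at 101..177:
    #
    #   rows (queries) \ cols (keys)   object      class      caption
    #   object queries                 attend      masked     masked
    #   class query                    masked      attend     masked
    #   caption queries                attend      attend     causal (triu)
    #
    # i.e. object queries and the class query are isolated from each other
    # and from the caption tokens, while caption tokens see the visual
    # queries plus a causal view of previous caption tokens.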

    def forward(self, x, mask_features, extra=None):
        if self.task == 'caption':
            return self.forward_caption(x, mask_features, extra)

        assert len(x) == self.num_feature_levels
        src = []
        pos = []
        size_list = []
        for i in range(self.num_feature_levels):
            size_list.append(x[i].shape[-2:])
            pos.append(self.pe_layer(x[i], None).flatten(2))
            src.append(self.input_proj[i](x[i]).flatten(2) +
                       self.level_embed.weight[i][None, :, None])
            # flatten NxCxHxW to HWxNxC
            pos[-1] = pos[-1].permute(2, 0, 1)
            src[-1] = src[-1].permute(2, 0, 1)

        _, bs, _ = src[0].shape
        query_embed = self.query_embed.weight.unsqueeze(1).repeat(1, bs, 1)
        output = self.query_feat.weight.unsqueeze(1).repeat(1, bs, 1)

        predictions_mask = []
        predictions_class_embed = []

        if self.task == 'ref-seg':
            self_tgt_mask = self.self_attn_mask[
                :, :self.num_queries, :self.num_queries].repeat(
                    output.shape[1] * self.num_heads, 1, 1)
            grounding_tokens = extra['grounding_tokens']
            _grounding_tokens = grounding_tokens.detach().clone()
            # initialize with negative attention at the beginning.
            pad_tgt_mask = torch.ones(
                (1, self.num_queries + (self.num_queries - 1) +
                 len(grounding_tokens), self.num_queries +
                 (self.num_queries - 1) + len(grounding_tokens)),
                device=self_tgt_mask.device).bool().repeat(
                    output.shape[1] * self.num_heads, 1, 1)
            pad_tgt_mask[:, :self.num_queries,
                         :self.num_queries] = self_tgt_mask
            # grounding tokens can attend to each other
            pad_tgt_mask[:, self.num_queries:, self.num_queries:] = False
            self_tgt_mask = pad_tgt_mask
            output = torch.cat((output, output[:-1]), dim=0)
            # also pad the language embedding so its length stays fixed
            query_embed = torch.cat((query_embed, query_embed[:-1]), dim=0)
        else:
            self_tgt_mask = self.self_attn_mask[
                :, :self.num_queries, :self.num_queries].repeat(
                    output.shape[1] * self.num_heads, 1, 1)

        results = self.forward_prediction_heads(
            output, mask_features, attn_mask_target_size=size_list[0])
        attn_mask = results['attn_mask']
        predictions_class_embed.append(results['class_embed'])
        predictions_mask.append(results['outputs_mask'])

        for i in range(self.num_layers):
            level_index = i % self.num_feature_levels
            # a fully-masked query would break the attention softmax, so let
            # such queries attend everywhere instead.
            attn_mask[torch.where(
                attn_mask.sum(-1) == attn_mask.shape[-1])] = False

            # attention: cross-attention first
            output, avg_attn = self.transformer_cross_attention_layers[i](
                output,
                src[level_index],
                memory_mask=attn_mask,
                # here we do not apply masking on padded region
                memory_key_padding_mask=None,
                pos=pos[level_index],
                query_pos=query_embed)

            if self.task == 'ref-seg':
                output = torch.cat((output, _grounding_tokens), dim=0)
                query_embed = torch.cat((query_embed, grounding_tokens),
                                        dim=0)

            output = self.transformer_self_attention_layers[i](
                output,
                tgt_mask=self_tgt_mask,
                tgt_key_padding_mask=None,
                query_pos=query_embed)
            output = self.transformer_ffn_layers[i](output)

            if self.task == 'ref-seg':
                _grounding_tokens = output[-len(_grounding_tokens):]
                output = output[:-len(_grounding_tokens)]
                query_embed = query_embed[:-len(_grounding_tokens)]

            results = self.forward_prediction_heads(
                output,
                mask_features,
                attn_mask_target_size=size_list[(i + 1) %
                                                self.num_feature_levels])
            attn_mask = results['attn_mask']
            predictions_mask.append(results['outputs_mask'])
            predictions_class_embed.append(results['class_embed'])

        out = {
            'pred_masks': predictions_mask[-1],
            'pred_class_embed': predictions_class_embed[-1],
        }

        if self.task == 'ref-seg':
            mask_pred_results = []
            outputs_class = []
            for idx in range(mask_features.shape[0]):  # batch size
                pred_gmasks = out['pred_masks'][
                    idx, self.num_queries:2 * self.num_queries - 1]
                v_emb = predictions_class_embed[-1][
                    idx, self.num_queries:2 * self.num_queries - 1]
                t_emb = extra['class_emb']

                t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
                v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)

                temperature = self.lang_encoder.logit_scale
                out_prob = vl_similarity(v_emb, t_emb,
                                         temperature=temperature)
                matched_id = out_prob.max(0)[1]
                mask_pred_results += [pred_gmasks[matched_id, :, :]]
                outputs_class += [out_prob[matched_id, :]]
            out['pred_masks'] = mask_pred_results
            out['pred_logits'] = outputs_class
        elif self.task == 'retrieval':
            t_emb = extra['class_emb']
            temperature = self.lang_encoder.logit_scale
            v_emb = out['pred_class_embed'][:, -1, :]
            v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
            logits = vl_similarity(v_emb, t_emb, temperature)
            out['pred_logits'] = logits
        elif self.task in ['semseg', 'instance', 'panoptic']:
            outputs_class = self.lang_encoder.compute_similarity(
                out['pred_class_embed'])
            out['pred_logits'] = outputs_class
        return out
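
    # Summary of the per-task outputs of ``forward``:
    #   * 'semseg' / 'instance' / 'panoptic': ``pred_logits`` comes from
    #     ``lang_encoder.compute_similarity`` over ``pred_class_embed``.
    #   * 'ref-seg': per image, the grounding queries are matched to the text
    #     embedding via ``vl_similarity`` and only the matched masks are kept,
    #     so ``pred_masks`` / ``pred_logits`` become per-image lists.
    #   * 'retrieval': the class-query embedding (last query) is compared to
    #     the text embedding, giving image-text ``pred_logits``.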

    def forward_caption(self, x, mask_features, extra=None):
        assert len(x) == self.num_feature_levels
        src = []
        pos = []
        size_list = []
        for i in range(self.num_feature_levels):
            size_list.append(x[i].shape[-2:])
            pos.append(self.pe_layer(x[i], None).flatten(2))
            src.append(self.input_proj[i](x[i]).flatten(2) +
                       self.level_embed.weight[i][None, :, None])
            # flatten NxCxHxW to HWxNxC
            pos[-1] = pos[-1].permute(2, 0, 1)
            src[-1] = src[-1].permute(2, 0, 1)

        _, bs, _ = src[0].shape

        # QxNxC
        query_embed_ = self.query_embed.weight.unsqueeze(1).repeat(1, bs, 1)
        query_feat = self.query_feat.weight.unsqueeze(1).repeat(1, bs, 1)
        lang_token = extra['start_token'].repeat(bs, 1)
        pos_embed = self.pos_embed_caping.weight.unsqueeze(1).repeat(1, bs, 1)

        # prepare token embedding for evaluation
        token_embs = self.lang_encoder.lang_encoder.token_embedding.weight

        for cap_idx in range(0, self.captioning_step):
            lang_embed = self.lang_encoder.forward_language(
                (lang_token, ), with_cls_embed=False)[1].transpose(0, 1)
            # concat object query, class token and caption token.
            output = torch.cat((query_feat, lang_embed), dim=0)
            lang_embed += pos_embed
            query_embed = torch.cat((query_embed_, lang_embed), dim=0)

            # prediction heads on learnable query features
            results = self.forward_prediction_heads(
                output, mask_features, attn_mask_target_size=size_list[0])
            attn_mask = results['attn_mask']

            for i in range(self.num_layers):
                level_index = i % self.num_feature_levels
                # a fully-masked query would break the attention softmax, so
                # let such queries attend everywhere instead.
                attn_mask[torch.where(
                    attn_mask.sum(-1) == attn_mask.shape[-1])] = False
                attn_mask = torch.cat(
                    (attn_mask,
                     torch.zeros_like(attn_mask[:, :self.max_token_num, :])),
                    dim=1)
                self_tgt_mask = self.self_attn_mask.repeat(
                    output.shape[1] * self.num_heads, 1, 1)

                if 'grounding_mask' in extra:
                    bs, nq, wh = attn_mask.shape
                    assert bs == self.num_heads, 'Only support single ' \
                        'image referring captioning.'
                    grounding_mask = extra['grounding_mask']
                    attn_mask = attn_mask.reshape(bs, nq, size_list[i % 3][0],
                                                  size_list[i % 3][1])
                    grounding_mask = F.interpolate(
                        grounding_mask.float(),
                        size_list[i % 3],
                        mode='nearest').bool()[0, 0]
                    attn_mask[:, self.num_queries:, grounding_mask] = True
                    attn_mask = attn_mask.reshape(bs, nq, wh)

                # attention: cross-attention first
                output, avg_attn = self.transformer_cross_attention_layers[i](
                    output,
                    src[level_index],
                    memory_mask=attn_mask,
                    # here we do not apply masking on padded region
                    memory_key_padding_mask=None,
                    pos=pos[level_index],
                    query_pos=query_embed)
                output = self.transformer_self_attention_layers[i](
                    output,
                    tgt_mask=self_tgt_mask,
                    tgt_key_padding_mask=None,
                    query_pos=query_embed)
                output = self.transformer_ffn_layers[i](output)

                results = self.forward_prediction_heads(
                    output,
                    mask_features,
                    attn_mask_target_size=size_list[(i + 1) %
                                                    self.num_feature_levels])
                attn_mask = results['attn_mask']

            pred_captions = results['outputs_caption']
            pred_captions = pred_captions @ token_embs.t()
            lang_token[:, cap_idx + 1] = pred_captions[:, cap_idx].max(-1)[1]

        texts = self.lang_encoder.tokenizer.batch_decode(
            lang_token, skip_special_tokens=False)
        texts_new = []
        for text in texts:
            text = text.split('<|endoftext|>')[0]
            text = text.replace('<|endoftext|>', '')
            text = text.replace('<|startoftext|>', '')
            text = text.strip()
            texts_new.append(text)
        out = {'pred_caption': texts_new}
        return out
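
    # Decoding scheme of ``forward_caption`` (summary): captions are generated
    # greedily for ``captioning_step`` steps. At each step the current token
    # sequence is re-encoded by the language encoder, passed through the
    # decoder together with the object/class queries, and the caption-query
    # outputs are projected onto the token-embedding matrix
    # (``pred_captions @ token_embs.t()``); the argmax over the vocabulary at
    # position ``cap_idx`` becomes token ``cap_idx + 1``. Decoding runs for a
    # fixed number of steps and the text is cut at '<|endoftext|>'.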

    def forward_prediction_heads(self, output, mask_features,
                                 attn_mask_target_size):
        decoder_output = self.decoder_norm(output)
        decoder_output = decoder_output.transpose(0, 1)

        if self.task == 'caption':
            outputs_caption = decoder_output[
                :, self.num_queries:] @ self.caping_embed

        # recompute class token output.
        norm_decoder_output = decoder_output / (
            decoder_output.norm(dim=-1, keepdim=True) + 1e-7)
        obj_token = norm_decoder_output[:, :self.num_queries - 1]
        cls_token = norm_decoder_output[
            :, self.num_queries - 1:self.num_queries]
        sim = (cls_token @ obj_token.transpose(1, 2)).softmax(-1)[:, 0, :,
                                                                  None]
        cls_token = (sim * decoder_output[:, :self.num_queries - 1]).sum(
            dim=1, keepdim=True)

        if self.task == 'ref-seg':
            decoder_output = torch.cat(
                (decoder_output[:, :self.num_queries - 1], cls_token,
                 decoder_output[:,
                                self.num_queries:2 * self.num_queries - 1]),
                dim=1)
        else:
            decoder_output = torch.cat(
                (decoder_output[:, :self.num_queries - 1], cls_token), dim=1)

        mask_embed = self.mask_embed(decoder_output)
        outputs_mask = torch.einsum('bqc,bchw->bqhw', mask_embed,
                                    mask_features)

        if is_lower_torch_version():
            attn_mask = F.interpolate(
                outputs_mask,
                size=attn_mask_target_size,
                mode='bicubic',
                align_corners=False)
        else:
            attn_mask = F.interpolate(
                outputs_mask,
                size=attn_mask_target_size,
                mode='bicubic',
                align_corners=False,
                antialias=True)

        # a query may attend only to locations where its predicted mask is
        # confident (sigmoid >= 0.5); everything else is masked out.
        attn_mask = (attn_mask.sigmoid().flatten(2).unsqueeze(1).repeat(
            1, self.num_heads, 1, 1).flatten(0, 1) < 0.5).bool()
        attn_mask = attn_mask.detach()
        attn_mask[:, self.num_queries:self.num_queries + 1].fill_(False)

        if self.task == 'caption':
            results = {
                'attn_mask': attn_mask,
                'outputs_caption': outputs_caption,
            }
            return results
        else:
            class_embed = decoder_output @ self.class_embed
            results = {
                'outputs_mask': outputs_mask,
                'attn_mask': attn_mask,
                'class_embed': class_embed,
            }
            return results
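

# Minimal standalone sketch (illustrative only, not used by the module): it
# reproduces the mask-to-attention-mask step of ``forward_prediction_heads``
# with toy tensors, assuming 2 queries, 4 heads and an 8x8 target size.
if __name__ == '__main__':
    num_heads, num_queries = 4, 2
    outputs_mask = torch.randn(1, num_queries, 32, 32)  # (bs, Q, H, W)
    attn_mask = F.interpolate(
        outputs_mask, size=(8, 8), mode='bicubic', align_corners=False)
    # True = the query may NOT attend to this location (low mask confidence).
    attn_mask = (attn_mask.sigmoid().flatten(2).unsqueeze(1).repeat(
        1, num_heads, 1, 1).flatten(0, 1) < 0.5).bool()
    print(attn_mask.shape)  # torch.Size([4, 2, 64]) -> (bs * heads, Q, HW)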