language_model.py

import os
from collections import OrderedDict

import torch
from mmcv.cnn.bricks import DropPath
from torch import nn
from transformers import CLIPTokenizer

from .utils import get_prompt_templates


# modified from https://github.com/microsoft/X-Decoder/blob/main/xdecoder/language/vlpencoder.py # noqa
class LanguageEncoder(nn.Module):
    """CLIP-style text encoder.

    Tokenizes the input texts with the CLIP tokenizer, encodes them with a
    causal Transformer and projects the features into the shared
    vision-language embedding space.
    """

    def __init__(
        self,
        tokenizer='openai/clip-vit-base-patch32',
        dim_lang=512,
        dim_projection=512,
    ):
        super().__init__()
        os.environ['TOKENIZERS_PARALLELISM'] = 'true'
        self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer)
        self.tokenizer.add_special_tokens(
            {'cls_token': self.tokenizer.eos_token})
        max_token_num = self.tokenizer.model_max_length
        self.lang_encoder = Transformer(max_token_num,
                                        self.tokenizer.vocab_size, dim_lang)
        # projection into the shared vision-language space; left
        # uninitialized here and expected to be loaded from a checkpoint
        self.lang_proj = nn.Parameter(torch.empty(dim_lang, dim_projection))
        self.max_token_num = max_token_num
        self.logit_scale = nn.Parameter(torch.ones([]))

    @torch.no_grad()
    def get_mean_embeds(self, class_names, name='default'):
        """Cache the prompt-ensembled text embedding of each class name as
        ``{name}_text_embeddings``."""

        def extract_mean_emb(txts):
            tokens = self.tokenizer(
                txts,
                padding='max_length',
                truncation=True,
                max_length=self.max_token_num,
                return_tensors='pt')
            clss_embedding, _ = self.forward_language(
                (tokens['input_ids'].cuda(), tokens['attention_mask'].cuda()),
                norm=True,
                with_token_embed=False)
            clss_embedding = clss_embedding.mean(dim=0)
            clss_embedding /= clss_embedding.norm()
            return clss_embedding

        templates = get_prompt_templates()
        clss_embeddings = []
        for clss in class_names:
            txts = [
                template.format(
                    clss.replace('-other',
                                 '').replace('-merged',
                                             '').replace('-stuff', ''))
                for template in templates
            ]
            clss_embeddings.append(extract_mean_emb(txts))
        text_emb = torch.stack(clss_embeddings, dim=0)
        setattr(self, '{}_text_embeddings'.format(name), text_emb)

    def get_text_embeds(self, txts, name='grounding', norm=False):
        tokens = self.tokenizer(
            txts,
            padding='max_length',
            truncation=True,
            max_length=self.max_token_num,
            return_tensors='pt')
        tokens = {key: value.cuda() for key, value in tokens.items()}
        class_emb, token_emb = self.forward_language(
            (tokens['input_ids'], tokens['attention_mask']), norm=norm)
        ret = {
            'tokens': tokens,
            'token_emb': token_emb,
            'class_emb': class_emb,
        }
        setattr(self, '{}_token_embeddings'.format(name), ret)
        return ret

    def get_sot_token(self, device):
        # 49406: CLIP SOT token <|startoftext|>
        # 77: CLIP context_length
        return torch.tensor([[49406] * 77], device=device)

    def compute_similarity(self, v_emb, name='default'):
        """Scaled cosine similarity between visual embeddings and the cached
        ``{name}_text_embeddings``."""
        v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
        t_emb = getattr(self, '{}_text_embeddings'.format(name))
        output = self.logit_scale.exp() * v_emb @ t_emb.unsqueeze(0).transpose(
            1, 2)
        return output

    def forward_language(self,
                         texts,
                         norm=False,
                         with_token_embed=True,
                         with_cls_embed=True):
        x = self.lang_encoder(*texts)
        hidden_x = x['last_hidden_state']

        class_embed = None
        if with_cls_embed:
            # take the feature at the EOT position
            # (the highest token id in each sequence)
            class_embed = hidden_x[torch.arange(hidden_x.size(0)),
                                   texts[0].argmax(dim=-1)]
            class_embed = class_embed @ self.lang_proj
            if norm:
                class_embed = class_embed / (
                    class_embed.norm(dim=-1, keepdim=True) + 1e-7)

        hidden_embed = None
        if with_token_embed:
            hidden_embed = hidden_x @ self.lang_proj
            if norm:
                hidden_embed = hidden_embed / (
                    hidden_embed.norm(dim=-1, keepdim=True) + 1e-7)
        return class_embed, hidden_embed
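
# Usage sketch (illustration only, not part of the original file), assuming a
# CUDA device and network access to download the CLIP tokenizer; `v_emb` is a
# hypothetical tensor of visual query embeddings:
#
#   lang = LanguageEncoder().cuda()
#   lang.get_mean_embeds(['cat', 'dog', 'tree-merged'], name='default')
#   # v_emb: (batch, num_queries, 512) visual embeddings
#   logits = lang.compute_similarity(v_emb, name='default')
#   # logits: (batch, num_queries, num_classes), scaled cosine similarity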


class Transformer(nn.Module):
    """Causal Transformer text encoder in the style of the CLIP text model."""

    def __init__(self,
                 context_length,
                 vocab_size,
                 width,
                 layers: int = 12,
                 heads: int = 8,
                 drop_path: float = 0.0,
                 autogressive: bool = True):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, width)
        self.context_length = context_length
        self.positional_embedding = nn.Parameter(
            torch.empty(self.context_length, width))
        self.width = width
        self.layers = layers
        self.autogressive = autogressive
        attn_mask = self.build_attention_mask() if autogressive else None
        # stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path, layers)]
        self.resblocks = nn.ModuleList([
            ResidualAttentionBlock(width, heads, attn_mask, dpr[i])
            for i in range(layers)
        ])
        self.ln_final = LayerNorm(width)

    @property
    def dim_out(self):
        return self.width
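
    # For illustration (not in the original file): with context_length = 4
    # the additive mask built below is
    #     [[0., -inf, -inf, -inf],
    #      [0.,   0., -inf, -inf],
    #      [0.,   0.,   0., -inf],
    #      [0.,   0.,   0.,   0.]]
    # i.e. each token may attend only to itself and earlier positions.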

    def build_attention_mask(self):
        # lazily create causal attention mask,
        # with full attention between the vision tokens
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float('-inf'))
        mask.triu_(1)  # zero out the lower diagonal
        return mask

    def forward(self, input_ids, attention_mask=None):
        key_padding_mask = (attention_mask == 0) if (
            not self.autogressive and attention_mask is not None) else None
        x = self.token_embedding(input_ids)  # [batch_size, n_ctx, d_model]
        x = x + self.positional_embedding
        x = x.permute(1, 0, 2)  # NLD -> LND
        for block in self.resblocks:
            x = block(x, key_padding_mask)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x)
        return {'last_hidden_state': x}


class LayerNorm(nn.Module):

    def __init__(self, hidden_size, eps=1e-12):
        """Construct a layernorm module in the TF style (epsilon inside the
        square root)."""
        super(LayerNorm, self).__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        pdtype = x.dtype
        x = x.float()
        u = x.mean(-1, keepdim=True)
        s = (x - u).pow(2).mean(-1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.variance_epsilon)
        return self.weight * x.to(pdtype) + self.bias


class QuickGELU(nn.Module):

    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)


class ResidualAttentionBlock(nn.Module):

    def __init__(self,
                 d_model: int,
                 n_head: int,
                 attn_mask: torch.Tensor = None,
                 drop_path: float = 0.0):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(
            OrderedDict([('c_fc', nn.Linear(d_model, d_model * 4)),
                         ('gelu', QuickGELU()),
                         ('c_proj', nn.Linear(d_model * 4, d_model))]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask
        self.drop_path = DropPath(
            drop_path) if drop_path > 0. else nn.Identity()

    def attention(self,
                  x: torch.Tensor,
                  key_padding_mask: torch.Tensor = None):
        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) \
            if self.attn_mask is not None else None
        return self.attn(
            x,
            x,
            x,
            key_padding_mask=key_padding_mask,
            need_weights=False,
            attn_mask=self.attn_mask)[0]

    def forward(self, x: torch.Tensor, key_padding_mask: torch.Tensor = None):
        x = x + self.drop_path(
            self.attention(self.ln_1(x), key_padding_mask=key_padding_mask))
        x = x + self.drop_path(self.mlp(self.ln_2(x)))
        return x
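
# Shape note (illustration, not in the original file): the blocks above follow
# the pre-LN CLIP design and use nn.MultiheadAttention with the default
# batch_first=False, so ResidualAttentionBlock expects inputs laid out as
# (seq_len, batch, d_model). A standalone sketch with a toy batch (the
# positional embedding is uninitialized here, so only the shapes are
# meaningful):
#
#   model = Transformer(context_length=77, vocab_size=49408, width=512)
#   ids = torch.randint(0, 49408, (2, 77))    # (batch, seq_len)
#   out = model(ids)['last_hidden_state']     # (2, 77, 512)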