import os
from collections import OrderedDict

import torch
from mmcv.cnn.bricks import DropPath
from torch import nn
from transformers import CLIPTokenizer

from .utils import get_prompt_templates


# modified from https://github.com/microsoft/X-Decoder/blob/main/xdecoder/language/vlpencoder.py  # noqa
class LanguageEncoder(nn.Module):
    """CLIP-style text encoder adapted from X-Decoder.

    Token sequences are encoded by a text ``Transformer`` and projected with
    ``lang_proj`` into a joint vision-language embedding space, where they can
    be compared with visual embeddings via ``compute_similarity``.
    """

    def __init__(
        self,
        tokenizer='openai/clip-vit-base-patch32',
        dim_lang=512,
        dim_projection=512,
    ):
        super().__init__()

        os.environ['TOKENIZERS_PARALLELISM'] = 'true'
        self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer)
        self.tokenizer.add_special_tokens(
            {'cls_token': self.tokenizer.eos_token})
        max_token_num = self.tokenizer.model_max_length
        self.lang_encoder = Transformer(max_token_num,
                                        self.tokenizer.vocab_size, dim_lang)

        # projection into the joint embedding space and the learnable
        # temperature; filled in when a pretrained checkpoint is loaded
        self.lang_proj = nn.Parameter(torch.empty(dim_lang, dim_projection))
        self.max_token_num = max_token_num
        self.logit_scale = nn.Parameter(torch.ones([]))

    @torch.no_grad()
    def get_mean_embeds(self, class_names, name='default'):
        """Cache the mean prompt embedding of each class name as
        ``{name}_text_embeddings`` for later similarity computation."""

        def extract_mean_emb(txts):
            tokens = self.tokenizer(
                txts,
                padding='max_length',
                truncation=True,
                max_length=self.max_token_num,
                return_tensors='pt')
            clss_embedding, _ = self.forward_language(
                (tokens['input_ids'].cuda(), tokens['attention_mask'].cuda()),
                norm=True,
                with_token_embed=False)
            clss_embedding = clss_embedding.mean(dim=0)
            clss_embedding /= clss_embedding.norm()
            return clss_embedding

        templates = get_prompt_templates()
        clss_embeddings = []
        for clss in class_names:
            # drop the '-other'/'-merged'/'-stuff' suffixes before filling
            # the prompt templates
            txts = [
                template.format(
                    clss.replace('-other', '').replace('-merged', '').replace(
                        '-stuff', ''))
                for template in templates
            ]
            clss_embeddings.append(extract_mean_emb(txts))

        text_emb = torch.stack(clss_embeddings, dim=0)
        setattr(self, '{}_text_embeddings'.format(name), text_emb)

    def get_text_embeds(self, txts, name='grounding', norm=False):
        tokens = self.tokenizer(
            txts,
            padding='max_length',
            truncation=True,
            max_length=self.max_token_num,
            return_tensors='pt')
        tokens = {key: value.cuda() for key, value in tokens.items()}

        class_emb, token_emb = self.forward_language(
            (tokens['input_ids'], tokens['attention_mask']), norm=norm)

        ret = {
            'tokens': tokens,
            'token_emb': token_emb,
            'class_emb': class_emb,
        }
        setattr(self, '{}_token_embeddings'.format(name), ret)
        return ret

    def get_sot_token(self, device):
        # 49406: CLIP SOT token <|startoftext|>
        # 77: CLIP context_length
        return torch.tensor([[49406] * 77], device=device)

    def compute_similarity(self, v_emb, name='default'):
        # v_emb: (batch, num_queries, dim_projection); returns the scaled
        # cosine similarity to the cached '{name}_text_embeddings',
        # shape (batch, num_queries, num_classes)
        v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
        t_emb = getattr(self, '{}_text_embeddings'.format(name))
        output = self.logit_scale.exp() * v_emb @ t_emb.unsqueeze(
            0).transpose(1, 2)
        return output

    def forward_language(self,
                         texts,
                         norm=False,
                         with_token_embed=True,
                         with_cls_embed=True):
        x = self.lang_encoder(*texts)
        hidden_x = x['last_hidden_state']

        class_embed = None
        if with_cls_embed:
            # pool the sequence at the EOT position (the token with the
            # largest id), as in CLIP, then project it
            class_embed = hidden_x[torch.arange(hidden_x.size(0)),
                                   texts[0].argmax(dim=-1)]
            class_embed = class_embed @ self.lang_proj
            if norm:
                class_embed = class_embed / (
                    class_embed.norm(dim=-1, keepdim=True) + 1e-7)

        hidden_embed = None
        if with_token_embed:
            hidden_embed = hidden_x @ self.lang_proj
            if norm:
                hidden_embed = hidden_embed / (
                    hidden_embed.norm(dim=-1, keepdim=True) + 1e-7)

        return class_embed, hidden_embed


class Transformer(nn.Module):

    def __init__(self,
                 context_length,
                 vocab_size,
                 width,
                 layers: int = 12,
                 heads: int = 8,
                 drop_path: float = 0.0,
                 autoregressive: bool = True):
        super().__init__()

        self.token_embedding = nn.Embedding(vocab_size, width)

        self.context_length = context_length
        self.positional_embedding = nn.Parameter(
            torch.empty(self.context_length, width))

        self.width = width
        self.layers = layers
        self.autoregressive = autoregressive
        attn_mask = self.build_attention_mask() if autoregressive else None
        # stochastic depth decay rule: drop-path rate grows linearly with depth
        dpr = [x.item() for x in torch.linspace(0, drop_path, layers)]
        self.resblocks = nn.ModuleList([
            ResidualAttentionBlock(width, heads, attn_mask, dpr[i])
            for i in range(layers)
        ])

        self.ln_final = LayerNorm(width)

    @property
    def dim_out(self):
        return self.width

    def build_attention_mask(self):
        # lazily create the causal attention mask for the text tokens;
        # pytorch uses an additive attention mask, so the strictly upper
        # triangle is filled with -inf and the rest is zero, e.g. for
        # context_length = 3:
        #   [[0., -inf, -inf],
        #    [0.,   0., -inf],
        #    [0.,   0.,   0.]]
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float('-inf'))
        mask.triu_(1)  # zero out the diagonal and below
        return mask

    def forward(self, input_ids, attention_mask=None):
        # the key padding mask is only needed for bidirectional
        # (non-autoregressive) attention; with a causal mask it stays None
        key_padding_mask = (attention_mask == 0) if (
            not self.autoregressive and attention_mask is not None) else None
        x = self.token_embedding(input_ids)  # [batch_size, n_ctx, d_model]
        x = x + self.positional_embedding
        x = x.permute(1, 0, 2)  # NLD -> LND
        for block in self.resblocks:
            x = block(x, key_padding_mask)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x)
        return {'last_hidden_state': x}


class LayerNorm(nn.Module):

    def __init__(self, hidden_size, eps=1e-12):
        """Construct a layernorm module in the TF style (epsilon inside the
        square root)."""
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        pdtype = x.dtype
        x = x.float()
        u = x.mean(-1, keepdim=True)
        s = (x - u).pow(2).mean(-1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.variance_epsilon)
        return self.weight * x.to(pdtype) + self.bias


class QuickGELU(nn.Module):

    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)


class ResidualAttentionBlock(nn.Module):

    def __init__(self,
                 d_model: int,
                 n_head: int,
                 attn_mask: torch.Tensor = None,
                 drop_path: float = 0.0):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(
            OrderedDict([('c_fc', nn.Linear(d_model, d_model * 4)),
                         ('gelu', QuickGELU()),
                         ('c_proj', nn.Linear(d_model * 4, d_model))]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask
        self.drop_path = DropPath(
            drop_path) if drop_path > 0. else nn.Identity()

    def attention(self,
                  x: torch.Tensor,
                  key_padding_mask: torch.Tensor = None):
        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) \
            if self.attn_mask is not None else None
        return self.attn(
            x,
            x,
            x,
            key_padding_mask=key_padding_mask,
            need_weights=False,
            attn_mask=self.attn_mask)[0]

    def forward(self, x: torch.Tensor, key_padding_mask: torch.Tensor = None):
        x = x + self.drop_path(
            self.attention(self.ln_1(x), key_padding_mask=key_padding_mask))
        x = x + self.drop_path(self.mlp(self.ln_2(x)))
        return x
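

# The block below is a minimal usage sketch added for illustration; it is not
# part of the original module. It assumes a CUDA device is available (the
# encoder moves token tensors to the GPU internally), that the CLIP tokenizer
# can be fetched from the HuggingFace Hub, and that the file is run as a
# package module (``python -m ...``) because of the relative import above.
# With randomly initialized weights the scores are meaningless; in practice
# the encoder is loaded from a pretrained X-Decoder checkpoint.
if __name__ == '__main__':
    encoder = LanguageEncoder().cuda().eval()
    # cache mean prompt embeddings for a toy label set under the name
    # 'default', then score a dummy batch of visual embeddings against it
    encoder.get_mean_embeds(['person', 'dog', 'tree'], name='default')
    v_emb = torch.randn(2, 100, 512, device='cuda')
    logits = encoder.compute_similarity(v_emb, name='default')
    print(logits.shape)  # torch.Size([2, 100, 3])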