# Copyright (c) OpenMMLab. All rights reserved.
from collections import OrderedDict
from typing import Sequence

import torch
from mmengine.model import BaseModel
from torch import nn

try:
    from transformers import AutoTokenizer, BertConfig
    from transformers import BertModel as HFBertModel
except ImportError:
    AutoTokenizer = None
    BertConfig = None
    HFBertModel = None

from mmdet.registry import MODELS


@MODELS.register_module()
class BertModel(BaseModel):
    """BERT language model for extracting text embeddings (encoder only).

    Args:
        name (str): Name of the pretrained BERT model on HuggingFace.
            Defaults to 'bert-base-uncased'.
        max_tokens (int): Maximum number of tokens fed to BERT.
            Defaults to 256.
        pad_to_max (bool): Whether to pad every caption to ``max_tokens``
            tokens; if False, pad only to the longest caption in the batch.
            Defaults to True.
        num_layers_of_embedded (int): Number of final hidden layers that are
            averaged to build the language embedding. Defaults to 1.
        use_checkpoint (bool): Whether to use gradient checkpointing.
            Defaults to False.
    """

    def __init__(self,
                 name: str = 'bert-base-uncased',
                 max_tokens: int = 256,
                 pad_to_max: bool = True,
                 num_layers_of_embedded: int = 1,
                 use_checkpoint: bool = False,
                 **kwargs) -> None:
        super().__init__(**kwargs)
        self.max_tokens = max_tokens
        self.pad_to_max = pad_to_max

        if AutoTokenizer is None:
            raise RuntimeError(
                'transformers is not installed, please install it by: '
                'pip install transformers.')

        self.tokenizer = AutoTokenizer.from_pretrained(name)
        self.language_backbone = nn.Sequential(
            OrderedDict([('body',
                          BertEncoder(
                              name,
                              num_layers_of_embedded=num_layers_of_embedded,
                              use_checkpoint=use_checkpoint))]))

    def forward(self, captions: Sequence[str], **kwargs) -> dict:
        """Tokenize captions and forward them through BERT."""
        device = next(self.language_backbone.parameters()).device
        tokenized = self.tokenizer.batch_encode_plus(
            captions,
            max_length=self.max_tokens,
            padding='max_length' if self.pad_to_max else 'longest',
            return_special_tokens_mask=True,
            return_tensors='pt',
            truncation=True).to(device)

        tokenizer_input = {
            'input_ids': tokenized.input_ids,
            'attention_mask': tokenized.attention_mask
        }
        language_dict_features = self.language_backbone(tokenizer_input)
        return language_dict_features
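

# Minimal usage sketch (illustrative only; assumes `transformers` is
# installed and the pretrained 'bert-base-uncased' weights are cached
# locally or can be downloaded):
#
#     language_model = BertModel(name='bert-base-uncased', max_tokens=256)
#     feats = language_model(['a photo of a cat', 'two dogs playing'])
#     # feats['embedded']: masked embedding, (batch, seq_len, hidden_size)
#     # feats['masks']:    the tokenizer attention mask
#     # feats['hidden']:   the last BERT hidden state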


class BertEncoder(nn.Module):
    """BERT encoder that produces masked language embeddings.

    Args:
        name (str): Name of the pretrained BERT model on HuggingFace,
            e.g. 'bert-base-uncased'.
        num_layers_of_embedded (int): Number of final hidden layers that are
            averaged to build the language embedding. Defaults to 1.
        use_checkpoint (bool): Whether to use gradient checkpointing.
            Defaults to False.
    """

    def __init__(self,
                 name: str,
                 num_layers_of_embedded: int = 1,
                 use_checkpoint: bool = False):
        super().__init__()
        if BertConfig is None:
            raise RuntimeError(
                'transformers is not installed, please install it by: '
                'pip install transformers.')

        config = BertConfig.from_pretrained(name)
        config.gradient_checkpointing = use_checkpoint
        # only the encoder is needed, so the pooling layer is dropped
        self.model = HFBertModel.from_pretrained(
            name, add_pooling_layer=False, config=config)
        self.language_dim = config.hidden_size
        self.num_layers_of_embedded = num_layers_of_embedded

    def forward(self, x: dict) -> dict:
        """Encode tokenized captions and build masked language embeddings."""
        mask = x['attention_mask']

        outputs = self.model(
            input_ids=x['input_ids'],
            attention_mask=mask,
            output_hidden_states=True,
        )
        # hidden_states has 13 entries for bert-base: the embedding output
        # followed by the 12 hidden layers; keep only the hidden layers.
        encoded_layers = outputs.hidden_states[1:]
        features = torch.stack(encoded_layers[-self.num_layers_of_embedded:],
                               1).mean(1)
        # language embedding has shape [len(phrase), seq_len, language_dim]
        features = features / self.num_layers_of_embedded
        embedded = features * mask.unsqueeze(-1).float()

        results = {
            'embedded': embedded,
            'masks': mask,
            'hidden': encoded_layers[-1]
        }
        return results
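

# Minimal, self-contained sketch (illustrative only) of the layer averaging
# performed in ``BertEncoder.forward``: the last ``num_layers_of_embedded``
# hidden states are stacked, averaged, divided again by the layer count
# (as in the forward pass above) and multiplied by the attention mask.
# Dummy tensors are used, so no pretrained weights are required.
if __name__ == '__main__':
    batch, seq_len, hidden_dim = 2, 8, 16
    num_layers_of_embedded = 2
    # Stand-ins for ``outputs.hidden_states[1:]`` of a 12-layer BERT.
    encoded_layers = [
        torch.randn(batch, seq_len, hidden_dim) for _ in range(12)
    ]
    mask = torch.ones(batch, seq_len)
    mask[:, 5:] = 0  # pretend the trailing tokens are padding

    features = torch.stack(encoded_layers[-num_layers_of_embedded:],
                           1).mean(1)
    features = features / num_layers_of_embedded
    embedded = features * mask.unsqueeze(-1).float()
    # Padding positions are zeroed out; the shape is (batch, seq_len, dim).
    print(embedded.shape)  # torch.Size([2, 8, 16])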