bert.py

# Copyright (c) OpenMMLab. All rights reserved.
from collections import OrderedDict
from typing import Sequence

import torch
from mmengine.model import BaseModel
from torch import nn

try:
    from transformers import AutoTokenizer, BertConfig
    from transformers import BertModel as HFBertModel
except ImportError:
    AutoTokenizer = None
    BertConfig = None
    HFBertModel = None

from mmdet.registry import MODELS


@MODELS.register_module()
class BertModel(BaseModel):
    """BERT model for language embedding (encoder only).

    Args:
        name (str): Name of the pretrained BERT model from HuggingFace.
            Defaults to bert-base-uncased.
        max_tokens (int): Maximum number of tokens to be used for BERT.
            Defaults to 256.
        pad_to_max (bool): Whether to pad the tokens to ``max_tokens``.
            Defaults to True.
        num_layers_of_embedded (int): Number of the last BERT hidden layers
            that are averaged to build the language embedding.
            Defaults to 1.
        use_checkpoint (bool): Whether to use gradient checkpointing.
            Defaults to False.
    """

    def __init__(self,
                 name: str = 'bert-base-uncased',
                 max_tokens: int = 256,
                 pad_to_max: bool = True,
                 num_layers_of_embedded: int = 1,
                 use_checkpoint: bool = False,
                 **kwargs) -> None:
        super().__init__(**kwargs)
        self.max_tokens = max_tokens
        self.pad_to_max = pad_to_max

        if AutoTokenizer is None:
            raise RuntimeError(
                'transformers is not installed, please install it by: '
                'pip install transformers.')

        self.tokenizer = AutoTokenizer.from_pretrained(name)
        self.language_backbone = nn.Sequential(
            OrderedDict([('body',
                          BertEncoder(
                              name,
                              num_layers_of_embedded=num_layers_of_embedded,
                              use_checkpoint=use_checkpoint))]))

    def forward(self, captions: Sequence[str], **kwargs) -> dict:
        """Tokenize the captions and run the BERT language backbone."""
        device = next(self.language_backbone.parameters()).device
        tokenized = self.tokenizer.batch_encode_plus(
            captions,
            max_length=self.max_tokens,
            padding='max_length' if self.pad_to_max else 'longest',
            return_special_tokens_mask=True,
            return_tensors='pt',
            truncation=True).to(device)

        tokenizer_input = {
            'input_ids': tokenized.input_ids,
            'attention_mask': tokenized.attention_mask
        }
        language_dict_features = self.language_backbone(tokenizer_input)
        return language_dict_features
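
# Usage sketch (assumes the 'bert-base-uncased' weights are cached locally or
# can be downloaded; the returned keys come from BertEncoder.forward below):
#
#     >>> language_model = BertModel(name='bert-base-uncased', max_tokens=256)
#     >>> # or via the registry: MODELS.build(dict(type='BertModel'))
#     >>> feats = language_model(['a cat on a mat', 'two dogs running'])
#     >>> feats['embedded'].shape  # (num_captions, max_tokens, hidden_size)
#     >>> feats['masks'].shape     # (num_captions, max_tokens)
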

class BertEncoder(nn.Module):
    """BERT encoder for language embedding.

    Args:
        name (str): Name of the pretrained BERT model from HuggingFace,
            e.g. bert-base-uncased.
        num_layers_of_embedded (int): Number of the last BERT hidden layers
            that are averaged to build the language embedding.
            Defaults to 1.
        use_checkpoint (bool): Whether to use gradient checkpointing.
            Defaults to False.
    """

    def __init__(self,
                 name: str,
                 num_layers_of_embedded: int = 1,
                 use_checkpoint: bool = False):
        super().__init__()
        if BertConfig is None:
            raise RuntimeError(
                'transformers is not installed, please install it by: '
                'pip install transformers.')
        config = BertConfig.from_pretrained(name)
        config.gradient_checkpointing = use_checkpoint
        # encoder only, without the pooling head
        self.model = HFBertModel.from_pretrained(
            name, add_pooling_layer=False, config=config)
        self.language_dim = config.hidden_size
        self.num_layers_of_embedded = num_layers_of_embedded

    def forward(self, x: dict) -> dict:
        mask = x['attention_mask']

        outputs = self.model(
            input_ids=x['input_ids'],
            attention_mask=mask,
            output_hidden_states=True,
        )

        # hidden_states has 13 entries for bert-base: the input embeddings
        # plus the outputs of the 12 hidden layers; drop the input embeddings
        encoded_layers = outputs.hidden_states[1:]
        features = torch.stack(encoded_layers[-self.num_layers_of_embedded:],
                               1).mean(1)
        # language embedding has shape [len(phrase), seq_len, language_dim]
        features = features / self.num_layers_of_embedded
        embedded = features * mask.unsqueeze(-1).float()
        results = {
            'embedded': embedded,
            'masks': mask,
            'hidden': encoded_layers[-1]
        }
        return results
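
# A minimal standalone sketch of how BertEncoder pools hidden states: the last
# `num_layers_of_embedded` layers are stacked, averaged, scaled and masked.
# Shapes assume a bert-base sized encoder (12 hidden layers, hidden_size=768);
# random tensors stand in for real BERT outputs.
if __name__ == '__main__':
    batch, seq_len, hidden_size = 2, 8, 768
    num_layers_of_embedded = 1
    # stand-ins for outputs.hidden_states[1:] of a 12-layer encoder
    encoded_layers = [
        torch.rand(batch, seq_len, hidden_size) for _ in range(12)
    ]
    mask = torch.ones(batch, seq_len, dtype=torch.long)

    features = torch.stack(encoded_layers[-num_layers_of_embedded:],
                           1).mean(1)
    features = features / num_layers_of_embedded
    embedded = features * mask.unsqueeze(-1).float()
    print(embedded.shape)  # torch.Size([2, 8, 768])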