vlfuse_helper.py

# Copyright (c) OpenMMLab. All rights reserved.
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from mmcv.cnn.bricks import DropPath
from torch import Tensor

try:
    from transformers import BertPreTrainedModel
    from transformers.modeling_utils import apply_chunking_to_forward
    from transformers.models.bert.modeling_bert import \
        BertAttention as HFBertAttention
    from transformers.models.bert.modeling_bert import \
        BertIntermediate as HFBertIntermediate
    from transformers.models.bert.modeling_bert import \
        BertOutput as HFBertOutput
except ImportError:
    BertPreTrainedModel = object
    apply_chunking_to_forward = None
    HFBertAttention = object
    HFBertIntermediate = object
    HFBertOutput = object

MAX_CLAMP_VALUE = 50000


def permute_and_flatten(layer, N, A, C, H, W):
    layer = layer.view(N, A, C, H, W)
    layer = layer.permute(0, 3, 4, 1, 2)
    layer = layer.reshape(N, -1, C)
    return layer
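

# Illustrative sketch (assumed shapes; not part of the original module):
# permute_and_flatten turns a per-level feature map into a sequence of
# C-dimensional tokens ordered by spatial position (and anchor, if A > 1).
#
#   x = torch.randn(2, 256, 32, 32)                       # (N, C, H, W), A = 1
#   tokens = permute_and_flatten(x, 2, -1, 256, 32, 32)
#   # tokens.shape == (2, 32 * 32, 256), i.e. [N, H * W * A, C]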


def clamp_values(vector):
    vector = torch.clamp(vector, min=-MAX_CLAMP_VALUE, max=MAX_CLAMP_VALUE)
    return vector


class BiMultiHeadAttention(nn.Module):
    """Bidirectional fusion Multi-Head Attention layer."""

    def __init__(self,
                 v_dim: int,
                 l_dim: int,
                 embed_dim: int,
                 num_heads: int,
                 dropout: float = 0.1):
        super(BiMultiHeadAttention, self).__init__()

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.v_dim = v_dim
        self.l_dim = l_dim

        assert (
            self.head_dim * self.num_heads == self.embed_dim
        ), 'embed_dim must be divisible by num_heads ' \
           f'(got `embed_dim`: {self.embed_dim} ' \
           f'and `num_heads`: {self.num_heads}).'
        self.scale = self.head_dim**(-0.5)
        self.dropout = dropout

        self.v_proj = nn.Linear(self.v_dim, self.embed_dim)
        self.l_proj = nn.Linear(self.l_dim, self.embed_dim)
        self.values_v_proj = nn.Linear(self.v_dim, self.embed_dim)
        self.values_l_proj = nn.Linear(self.l_dim, self.embed_dim)

        self.out_v_proj = nn.Linear(self.embed_dim, self.v_dim)
        self.out_l_proj = nn.Linear(self.embed_dim, self.l_dim)

        self.stable_softmax_2d = False
        self.clamp_min_for_underflow = True
        self.clamp_max_for_overflow = True

        self._reset_parameters()

    def _shape(self, tensor: Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads,
                           self.head_dim).transpose(1, 2).contiguous()

    def _reset_parameters(self):
        nn.init.xavier_uniform_(self.v_proj.weight)
        self.v_proj.bias.data.fill_(0)

        nn.init.xavier_uniform_(self.l_proj.weight)
        self.l_proj.bias.data.fill_(0)

        nn.init.xavier_uniform_(self.values_v_proj.weight)
        self.values_v_proj.bias.data.fill_(0)

        nn.init.xavier_uniform_(self.values_l_proj.weight)
        self.values_l_proj.bias.data.fill_(0)

        nn.init.xavier_uniform_(self.out_v_proj.weight)
        self.out_v_proj.bias.data.fill_(0)

        nn.init.xavier_uniform_(self.out_l_proj.weight)
        self.out_l_proj.bias.data.fill_(0)

    def forward(self, vision: Tensor, lang: Tensor, attention_mask_l=None):
        bsz, tgt_len, _ = vision.size()

        query_states = self.v_proj(vision) * self.scale
        key_states = self._shape(self.l_proj(lang), -1, bsz)
        value_v_states = self._shape(self.values_v_proj(vision), -1, bsz)
        value_l_states = self._shape(self.values_l_proj(lang), -1, bsz)

        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len,
                                   bsz).view(*proj_shape)
        key_states = key_states.view(*proj_shape)
        value_v_states = value_v_states.view(*proj_shape)
        value_l_states = value_l_states.view(*proj_shape)

        src_len = key_states.size(1)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f'Attention weights should be of '
                f'size {(bsz * self.num_heads, tgt_len, src_len)}, '
                f'but is {attn_weights.size()}')

        if self.stable_softmax_2d:
            attn_weights = attn_weights - attn_weights.max()

        if self.clamp_min_for_underflow:
            # Do not increase -50000, data type half has quite limited range
            attn_weights = torch.clamp(attn_weights, min=-MAX_CLAMP_VALUE)
        if self.clamp_max_for_overflow:
            # Do not increase 50000, data type half has quite limited range
            attn_weights = torch.clamp(attn_weights, max=MAX_CLAMP_VALUE)

        attn_weights_T = attn_weights.transpose(1, 2)
        attn_weights_l = (
            attn_weights_T -
            torch.max(attn_weights_T, dim=-1, keepdim=True)[0])
        if self.clamp_min_for_underflow:
            # Do not increase -50000, data type half has quite limited range
            attn_weights_l = torch.clamp(attn_weights_l, min=-MAX_CLAMP_VALUE)
        if self.clamp_max_for_overflow:
            # Do not increase 50000, data type half has quite limited range
            attn_weights_l = torch.clamp(attn_weights_l, max=MAX_CLAMP_VALUE)

        attn_weights_l = attn_weights_l.softmax(dim=-1)

        if attention_mask_l is not None:
            assert (attention_mask_l.dim() == 2)
            attention_mask = attention_mask_l.unsqueeze(1).unsqueeze(1)
            attention_mask = attention_mask.expand(bsz, 1, tgt_len, src_len)
            attention_mask = attention_mask.masked_fill(
                attention_mask == 0, -9e15)

            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError('Attention mask should be of '
                                 f'size {(bsz, 1, tgt_len, src_len)}')
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len,
                                             src_len) + attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len,
                                             src_len)

        attn_weights_v = nn.functional.softmax(attn_weights, dim=-1)

        attn_probs_v = F.dropout(
            attn_weights_v, p=self.dropout, training=self.training)
        attn_probs_l = F.dropout(
            attn_weights_l, p=self.dropout, training=self.training)

        attn_output_v = torch.bmm(attn_probs_v, value_l_states)
        attn_output_l = torch.bmm(attn_probs_l, value_v_states)

        if attn_output_v.size() != (bsz * self.num_heads, tgt_len,
                                    self.head_dim):
            raise ValueError(
                '`attn_output_v` should be of '
                f'size {(bsz, self.num_heads, tgt_len, self.head_dim)}, '
                f'but is {attn_output_v.size()}')
        if attn_output_l.size() != (bsz * self.num_heads, src_len,
                                    self.head_dim):
            raise ValueError(
                '`attn_output_l` should be of size '
                f'{(bsz, self.num_heads, src_len, self.head_dim)}, '
                f'but is {attn_output_l.size()}')

        attn_output_v = attn_output_v.view(bsz, self.num_heads, tgt_len,
                                           self.head_dim)
        attn_output_v = attn_output_v.transpose(1, 2)
        attn_output_v = attn_output_v.reshape(bsz, tgt_len, self.embed_dim)

        attn_output_l = attn_output_l.view(bsz, self.num_heads, src_len,
                                           self.head_dim)
        attn_output_l = attn_output_l.transpose(1, 2)
        attn_output_l = attn_output_l.reshape(bsz, src_len, self.embed_dim)

        attn_output_v = self.out_v_proj(attn_output_v)
        attn_output_l = self.out_l_proj(attn_output_l)

        return attn_output_v, attn_output_l
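

# Illustrative sketch (assumed shapes; not part of the original module):
# BiMultiHeadAttention attends in both directions at once and returns the
# updates for the visual tokens and for the language tokens.
#
#   attn = BiMultiHeadAttention(v_dim=256, l_dim=768, embed_dim=2048,
#                               num_heads=8)
#   vision = torch.randn(2, 1024, 256)   # [bs, num_visual_tokens, v_dim]
#   lang = torch.randn(2, 20, 768)       # [bs, num_text_tokens, l_dim]
#   delta_v, delta_l = attn(vision, lang)
#   # delta_v: [2, 1024, 256], delta_l: [2, 20, 768]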


class BiAttentionBlock(nn.Module):
    """BiAttentionBlock Module.

    First, the multi-level visual features are concatenated; then the
    concatenated visual features and the language features are fused by
    attention; finally, the fused visual features are split back into the
    multi-level form.
    """

    def __init__(self,
                 v_dim: int,
                 l_dim: int,
                 embed_dim: int,
                 num_heads: int,
                 dropout: float = 0.1,
                 drop_path: float = .0,
                 init_values: float = 1e-4):
        super().__init__()

        # pre layer norm
        self.layer_norm_v = nn.LayerNorm(v_dim)
        self.layer_norm_l = nn.LayerNorm(l_dim)
        self.attn = BiMultiHeadAttention(
            v_dim=v_dim,
            l_dim=l_dim,
            embed_dim=embed_dim,
            num_heads=num_heads,
            dropout=dropout)

        # add layer scale for training stability
        self.drop_path = DropPath(
            drop_path) if drop_path > 0. else nn.Identity()
        self.gamma_v = nn.Parameter(
            init_values * torch.ones(v_dim), requires_grad=True)
        self.gamma_l = nn.Parameter(
            init_values * torch.ones(l_dim), requires_grad=True)

    def forward(self,
                visual_features: list,
                lang_feature: Tensor,
                attention_mask_l=None):
        size_per_level, visual_features_flatten = [], []
        for i, feat_per_level in enumerate(visual_features):
            bs, c, h, w = feat_per_level.shape
            size_per_level.append([h, w])
            feat = permute_and_flatten(feat_per_level, bs, -1, c, h, w)
            visual_features_flatten.append(feat)
        visual_features_flatten = torch.cat(visual_features_flatten, dim=1)
        new_v, new_lang_feature = self.single_attention_call(
            visual_features_flatten,
            lang_feature,
            attention_mask_l=attention_mask_l)
        # [bs, N, C] -> [bs, C, N]
        new_v = new_v.transpose(1, 2).contiguous()

        start = 0
        fusion_visual_features = []
        for (h, w) in size_per_level:
            new_v_per_level = new_v[:, :,
                                    start:start + h * w].view(bs, -1, h,
                                                              w).contiguous()
            fusion_visual_features.append(new_v_per_level)
            start += h * w

        return fusion_visual_features, new_lang_feature

    def single_attention_call(self, visual, lang, attention_mask_l=None):
        visual = self.layer_norm_v(visual)
        lang = self.layer_norm_l(lang)
        delta_v, delta_l = self.attn(
            visual, lang, attention_mask_l=attention_mask_l)
        # visual, lang = visual + delta_v, lang + delta_l
        visual = visual + self.drop_path(self.gamma_v * delta_v)
        lang = lang + self.drop_path(self.gamma_l * delta_l)
        return visual, lang
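

# Illustrative sketch (assumed shapes; not part of the original module):
# BiAttentionBlock flattens and concatenates the multi-level visual features,
# fuses them with the language features, then splits the result back per level.
#
#   block = BiAttentionBlock(v_dim=256, l_dim=768, embed_dim=2048, num_heads=8)
#   feats = [torch.randn(2, 256, 32, 32), torch.randn(2, 256, 16, 16)]
#   lang = torch.randn(2, 20, 768)
#   fused_feats, fused_lang = block(feats, lang)
#   # fused_feats[0]: [2, 256, 32, 32], fused_feats[1]: [2, 256, 16, 16]
#   # fused_lang: [2, 20, 768]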


class VLFuse(nn.Module):
    """Early Fusion Module."""

    def __init__(self,
                 v_dim: int = 256,
                 l_dim: int = 768,
                 embed_dim: int = 2048,
                 num_heads: int = 8,
                 dropout: float = 0.1,
                 drop_path: float = 0.0,
                 use_checkpoint: bool = False):
        super().__init__()
        # bi-direction (text->image, image->text)
        self.use_checkpoint = use_checkpoint
        self.b_attn = BiAttentionBlock(
            v_dim=v_dim,
            l_dim=l_dim,
            embed_dim=embed_dim,
            num_heads=num_heads,
            dropout=dropout,
            drop_path=drop_path,
            init_values=1.0 / 6.0)

    def forward(self, x):
        visual_features = x['visual']
        language_dict_features = x['lang']

        if self.use_checkpoint:
            fused_visual_features, language_features = checkpoint.checkpoint(
                self.b_attn, visual_features, language_dict_features['hidden'],
                language_dict_features['masks'])
        else:
            fused_visual_features, language_features = self.b_attn(
                visual_features, language_dict_features['hidden'],
                language_dict_features['masks'])

        language_dict_features['hidden'] = language_features
        fused_language_dict_features = language_dict_features

        features_dict = {
            'visual': fused_visual_features,
            'lang': fused_language_dict_features
        }

        return features_dict
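

# Illustrative sketch (assumed shapes; not part of the original module):
# VLFuse consumes and produces the dict format used throughout the fusion
# pipeline, with multi-level visual features under 'visual' and the language
# hidden states and padding masks under 'lang'.
#
#   fuse = VLFuse()  # defaults: v_dim=256, l_dim=768, embed_dim=2048, 8 heads
#   x = {
#       'visual': [torch.randn(2, 256, 32, 32), torch.randn(2, 256, 16, 16)],
#       'lang': {
#           'hidden': torch.randn(2, 20, 768),
#           'masks': torch.ones(2, 20),
#       },
#   }
#   out = fuse(x)
#   # out['visual'] is the list of fused per-level features;
#   # out['lang']['hidden'] is the fused language feature.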


class BertEncoderLayer(BertPreTrainedModel):
    """Modified from transformers.models.bert.modeling_bert.BertLayer."""

    def __init__(self,
                 config,
                 clamp_min_for_underflow: bool = False,
                 clamp_max_for_overflow: bool = False):
        super().__init__(config)
        self.config = config

        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1

        self.attention = BertAttention(config, clamp_min_for_underflow,
                                       clamp_max_for_overflow)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def forward(self, inputs):
        language_dict_features = inputs['lang']
        hidden_states = language_dict_features['hidden']
        attention_mask = language_dict_features['masks']

        device = hidden_states.device
        input_shape = hidden_states.size()[:-1]
        # We can provide a self-attention mask of dimensions
        # [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it
        # broadcastable to all heads.
        extended_attention_mask = self.get_extended_attention_mask(
            attention_mask, input_shape, device)

        self_attention_outputs = self.attention(
            hidden_states,
            extended_attention_mask,
            None,
            output_attentions=False,
            past_key_value=None,
        )
        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[
            1:]  # add self attentions if we output attention weights
        layer_output = apply_chunking_to_forward(self.feed_forward_chunk,
                                                 self.chunk_size_feed_forward,
                                                 self.seq_len_dim,
                                                 attention_output)
        outputs = (layer_output, ) + outputs
        hidden_states = outputs[0]

        language_dict_features['hidden'] = hidden_states

        features_dict = {
            'visual': inputs['visual'],
            'lang': language_dict_features
        }

        return features_dict

    def feed_forward_chunk(self, attention_output):
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output
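

# Illustrative sketch (assumes `transformers` is installed; the config and
# shapes below are examples only, not part of the original module):
# BertEncoderLayer refines the language branch of the same dict format and
# passes the visual branch through unchanged.
#
#   from transformers import BertConfig
#   layer = BertEncoderLayer(BertConfig(),  # hidden_size defaults to 768
#                            clamp_min_for_underflow=True,
#                            clamp_max_for_overflow=True)
#   inputs = {
#       'visual': [torch.randn(2, 256, 32, 32)],
#       'lang': {'hidden': torch.randn(2, 20, 768), 'masks': torch.ones(2, 20)},
#   }
#   out = layer(inputs)   # out['lang']['hidden']: [2, 20, 768]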


# The following code is the same as the HuggingFace code,
# with the only difference being the additional clamp operation.


class BertSelfAttention(nn.Module):
    """BERT self-attention layer from HuggingFace transformers.

    Compared to the BertSelfAttention of HuggingFace, it only adds the clamp.
    """

    def __init__(self,
                 config,
                 clamp_min_for_underflow: bool = False,
                 clamp_max_for_overflow: bool = False):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and \
                not hasattr(config, 'embedding_size'):
            raise ValueError(f'The hidden size ({config.hidden_size}) is '
                             'not a multiple of the number of attention '
                             f'heads ({config.num_attention_heads})')

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size /
                                       config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * \
            self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = getattr(config,
                                               'position_embedding_type',
                                               'absolute')
        if self.position_embedding_type == 'relative_key' or \
                self.position_embedding_type == 'relative_key_query':
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(
                2 * config.max_position_embeddings - 1,
                self.attention_head_size)
        self.clamp_min_for_underflow = clamp_min_for_underflow
        self.clamp_max_for_overflow = clamp_max_for_overflow

        self.is_decoder = config.is_decoder

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads,
                                       self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        mixed_query_layer = self.query(hidden_states)

        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention and past_key_value is not None:
            # reuse k, v, cross_attentions
            key_layer = past_key_value[0]
            value_layer = past_key_value[1]
            attention_mask = encoder_attention_mask
        elif is_cross_attention:
            key_layer = self.transpose_for_scores(
                self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(
                self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        query_layer = self.transpose_for_scores(mixed_query_layer)

        if self.is_decoder:
            past_key_value = (key_layer, value_layer)

        # Take the dot product between "query" and "key"
        # to get the raw attention scores.
        attention_scores = torch.matmul(query_layer,
                                        key_layer.transpose(-1, -2))

        if self.position_embedding_type == 'relative_key' or \
                self.position_embedding_type == 'relative_key_query':
            seq_length = hidden_states.size()[1]
            position_ids_l = torch.arange(
                seq_length, dtype=torch.long,
                device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(
                seq_length, dtype=torch.long,
                device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r
            positional_embedding = self.distance_embedding(
                distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(
                dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == 'relative_key':
                relative_position_scores = torch.einsum(
                    'bhld,lrd->bhlr', query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == 'relative_key_query':
                relative_position_scores_query = torch.einsum(
                    'bhld,lrd->bhlr', query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum(
                    'bhrd,lrd->bhlr', key_layer, positional_embedding)
                attention_scores = attention_scores + \
                    relative_position_scores_query + \
                    relative_position_scores_key

        attention_scores = attention_scores / math.sqrt(
            self.attention_head_size)

        if self.clamp_min_for_underflow:
            attention_scores = torch.clamp(
                attention_scores, min=-MAX_CLAMP_VALUE
            )  # Do not increase -50000, data type half has quite limited range
        if self.clamp_max_for_overflow:
            attention_scores = torch.clamp(
                attention_scores, max=MAX_CLAMP_VALUE
            )  # Do not increase 50000, data type half has quite limited range

        if attention_mask is not None:
            # Apply the attention mask
            # (precomputed for all layers in BertModel forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (
            self.all_head_size, )
        context_layer = context_layer.view(*new_context_layer_shape)

        outputs = (context_layer,
                   attention_probs) if output_attentions else (context_layer, )

        if self.is_decoder:
            outputs = outputs + (past_key_value, )
        return outputs


class BertAttention(HFBertAttention):
    """BertAttention is made up of self-attention and intermediate+output.

    Compared to the BertAttention of HuggingFace, it only adds the clamp.
    """

    def __init__(self,
                 config,
                 clamp_min_for_underflow: bool = False,
                 clamp_max_for_overflow: bool = False):
        super().__init__(config)
        self.self = BertSelfAttention(config, clamp_min_for_underflow,
                                      clamp_max_for_overflow)


class BertIntermediate(HFBertIntermediate):
    """Same as the HuggingFace BertIntermediate, with values clamped before
    and after the activation for fp16 stability."""

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = clamp_values(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        hidden_states = clamp_values(hidden_states)
        return hidden_states


class BertOutput(HFBertOutput):
    """Same as the HuggingFace BertOutput, with values clamped after the dense
    layer and after the LayerNorm for fp16 stability."""

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = clamp_values(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        hidden_states = clamp_values(hidden_states)
        return hidden_states