dino_layers.py

# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from typing import Tuple, Union

import torch
from mmengine.model import BaseModule
from torch import Tensor, nn

from mmdet.structures import SampleList
from mmdet.structures.bbox import bbox_xyxy_to_cxcywh
from mmdet.utils import OptConfigType

from .deformable_detr_layers import DeformableDetrTransformerDecoder
from .utils import MLP, coordinate_to_encoding, inverse_sigmoid


class DinoTransformerDecoder(DeformableDetrTransformerDecoder):
    """Transformer decoder of DINO."""

    def _init_layers(self) -> None:
        """Initialize decoder layers."""
        super()._init_layers()
        self.ref_point_head = MLP(self.embed_dims * 2, self.embed_dims,
                                  self.embed_dims, 2)
        self.norm = nn.LayerNorm(self.embed_dims)

    def forward(self, query: Tensor, value: Tensor, key_padding_mask: Tensor,
                self_attn_mask: Tensor, reference_points: Tensor,
                spatial_shapes: Tensor, level_start_index: Tensor,
                valid_ratios: Tensor, reg_branches: nn.ModuleList,
                **kwargs) -> Tuple[Tensor]:
        """Forward function of Transformer decoder.

        Args:
            query (Tensor): The input query, has shape (num_queries, bs, dim).
            value (Tensor): The input values, has shape (num_value, bs, dim).
            key_padding_mask (Tensor): The `key_padding_mask` of `self_attn`
                input. ByteTensor, has shape (num_queries, bs).
            self_attn_mask (Tensor): The attention mask to prevent information
                leakage from different denoising groups and matching parts,
                has shape (num_queries_total, num_queries_total). It is `None`
                when `self.training` is `False`.
            reference_points (Tensor): The initial reference, has shape
                (bs, num_queries, 4) with the last dimension arranged as
                (cx, cy, w, h).
            spatial_shapes (Tensor): Spatial shapes of features in all levels,
                has shape (num_levels, 2), last dimension represents (h, w).
            level_start_index (Tensor): The start index of each level.
                A tensor has shape (num_levels, ) and can be represented
                as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
            valid_ratios (Tensor): The ratios of the valid width and the valid
                height relative to the width and the height of features in all
                levels, has shape (bs, num_levels, 2).
            reg_branches (obj:`nn.ModuleList`): Used for refining the
                regression results.

        Returns:
            tuple[Tensor]: Output queries and references of Transformer
            decoder.

            - query (Tensor): Output embeddings of the last decoder layer,
              has shape (num_queries, bs, embed_dims) when
              `return_intermediate` is `False`. Otherwise, intermediate output
              embeddings of all decoder layers, has shape
              (num_decoder_layers, num_queries, bs, embed_dims).
            - reference_points (Tensor): The reference of the last decoder
              layer, has shape (bs, num_queries, 4) when `return_intermediate`
              is `False`. Otherwise, intermediate references of all decoder
              layers, has shape (num_decoder_layers, bs, num_queries, 4). The
              coordinates are arranged as (cx, cy, w, h).
        """
        intermediate = []
        intermediate_reference_points = [reference_points]
        for lid, layer in enumerate(self.layers):
            if reference_points.shape[-1] == 4:
                reference_points_input = \
                    reference_points[:, :, None] * torch.cat(
                        [valid_ratios, valid_ratios], -1)[:, None]
            else:
                assert reference_points.shape[-1] == 2
                reference_points_input = \
                    reference_points[:, :, None] * valid_ratios[:, None]
            query_sine_embed = coordinate_to_encoding(
                reference_points_input[:, :, 0, :])
            query_pos = self.ref_point_head(query_sine_embed)

            query = layer(
                query,
                query_pos=query_pos,
                value=value,
                key_padding_mask=key_padding_mask,
                self_attn_mask=self_attn_mask,
                spatial_shapes=spatial_shapes,
                level_start_index=level_start_index,
                valid_ratios=valid_ratios,
                reference_points=reference_points_input,
                **kwargs)

            if reg_branches is not None:
                tmp = reg_branches[lid](query)
                assert reference_points.shape[-1] == 4
                new_reference_points = tmp + inverse_sigmoid(
                    reference_points, eps=1e-3)
                new_reference_points = new_reference_points.sigmoid()
                reference_points = new_reference_points.detach()

            if self.return_intermediate:
                intermediate.append(self.norm(query))
                intermediate_reference_points.append(new_reference_points)
                # NOTE this is for the "Look Forward Twice" module;
                # in DeformDETR, `reference_points` was appended instead.

        if self.return_intermediate:
            return torch.stack(intermediate), torch.stack(
                intermediate_reference_points)

        return query, reference_points
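
# NOTE (editor's illustration, not part of the original file): the box
# refinement above computes `delta + inverse_sigmoid(ref)` and re-applies
# sigmoid. A minimal sketch of that update in plain torch, using
# `torch.logit` as a stand-in for `inverse_sigmoid` (both map (0, 1) back
# to logit space):
#
#     import torch
#     ref = torch.tensor([[0.50, 0.50, 0.20, 0.20]])     # (cx, cy, w, h)
#     delta = torch.tensor([[0.10, -0.10, 0.00, 0.00]])  # reg branch output
#     new_ref = (delta + torch.logit(ref, eps=1e-3)).sigmoid()
#     # new_ref ~ [[0.525, 0.475, 0.200, 0.200]]
#
# "Look Forward Twice" then appends the *undetached* `new_reference_points`,
# so the gradient of layer `lid` also flows into the reference produced by
# layer `lid - 1`, while `reference_points` itself is detached before the
# next layer.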


class CdnQueryGenerator(BaseModule):
    """Implement query generator of the Contrastive DeNoising (CDN) proposed
    in `DINO: DETR with Improved DeNoising Anchor Boxes for End-to-End Object
    Detection <https://arxiv.org/abs/2203.03605>`_.

    Code is modified from the `official github repo
    <https://github.com/IDEA-Research/DINO>`_.

    Args:
        num_classes (int): Number of object classes.
        embed_dims (int): The embedding dimensions of the generated queries.
        num_matching_queries (int): The number of queries of the matching
            part. Used for generating dn_mask.
        label_noise_scale (float): The scale of label noise, defaults to 0.5.
        box_noise_scale (float): The scale of box noise, defaults to 1.0.
        group_cfg (:obj:`ConfigDict` or dict, optional): The config of the
            denoising queries grouping, includes `dynamic`, `num_dn_queries`,
            and `num_groups`. Two grouping strategies, 'static dn groups' and
            'dynamic dn groups', are supported. When `dynamic` is `False`,
            `num_groups` should be set, and the number of denoising query
            groups will always be `num_groups`. When `dynamic` is `True`,
            `num_dn_queries` should be set, and the group number will be
            dynamic to ensure that the number of denoising queries does not
            exceed `num_dn_queries`, to prevent large fluctuations of memory.
            Defaults to `None`.
    """

    def __init__(self,
                 num_classes: int,
                 embed_dims: int,
                 num_matching_queries: int,
                 label_noise_scale: float = 0.5,
                 box_noise_scale: float = 1.0,
                 group_cfg: OptConfigType = None) -> None:
        super().__init__()
        self.num_classes = num_classes
        self.embed_dims = embed_dims
        self.num_matching_queries = num_matching_queries
        self.label_noise_scale = label_noise_scale
        self.box_noise_scale = box_noise_scale

        # prepare grouping strategy
        group_cfg = {} if group_cfg is None else group_cfg
        self.dynamic_dn_groups = group_cfg.get('dynamic', True)
        if self.dynamic_dn_groups:
            if 'num_dn_queries' not in group_cfg:
                warnings.warn("'num_dn_queries' should be set when using "
                              'dynamic dn groups, use 100 as default.')
            self.num_dn_queries = group_cfg.get('num_dn_queries', 100)
            assert isinstance(self.num_dn_queries, int), \
                f'Expected the num_dn_queries to have type int, but got ' \
                f'{self.num_dn_queries}({type(self.num_dn_queries)}). '
        else:
            assert 'num_groups' in group_cfg, \
                'num_groups should be set when using static dn groups'
            self.num_groups = group_cfg['num_groups']
            assert isinstance(self.num_groups, int), \
                f'Expected the num_groups to have type int, but got ' \
                f'{self.num_groups}({type(self.num_groups)}). '

        # NOTE The original DINO repo sets num_embeddings to 92 for COCO:
        # indices 0~90 represent the target classes and index 91 indicates
        # the `unknown` class. However, the embedding of the `unknown` class
        # is not used in the original DINO.
        # TODO: num_classes + 1 or num_classes ?
        self.label_embedding = nn.Embedding(self.num_classes, self.embed_dims)
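
    # Editor's illustration (not part of the original file): a typical way to
    # build this generator; the numbers below (80 classes, 900 matching
    # queries, 100 dn queries) are assumptions matching common DINO configs:
    #
    #     dn_generator = CdnQueryGenerator(
    #         num_classes=80,
    #         embed_dims=256,
    #         num_matching_queries=900,
    #         label_noise_scale=0.5,
    #         box_noise_scale=1.0,
    #         group_cfg=dict(dynamic=True, num_dn_queries=100))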

    def __call__(self, batch_data_samples: SampleList) -> tuple:
        """Generate contrastive denoising (cdn) queries with ground truth.

        Descriptions of the number values in code and comments:
            - num_target_total: the total number of targets in the input
              batch samples.
            - max_num_target: the max number of targets among the input
              batch samples.
            - num_noisy_targets: the total number of targets after adding
              noise, i.e., num_target_total * num_groups * 2.
            - num_denoising_queries: the length of the output batched
              queries, i.e., max_num_target * num_groups * 2.

        NOTE The format of input bboxes in batch_data_samples is unnormalized
        (x, y, x, y), and the output bbox queries are embedded by normalized
        (cx, cy, w, h) format bboxes going through inverse_sigmoid.

        Args:
            batch_data_samples (list[:obj:`DetDataSample`]): List of the
                batch data samples, each includes `gt_instance` which has
                attributes `bboxes` and `labels`. The `bboxes` has
                unnormalized coordinate format (x, y, x, y).

        Returns:
            tuple: The outputs of the dn query generator.

            - dn_label_query (Tensor): The output content queries for the
              denoising part, has shape (bs, num_denoising_queries, dim),
              where `num_denoising_queries = max_num_target * num_groups * 2`.
            - dn_bbox_query (Tensor): The output reference bboxes as positions
              of queries for the denoising part, which are embedded by
              normalized (cx, cy, w, h) format bboxes going through
              inverse_sigmoid, has shape (bs, num_denoising_queries, 4) with
              the last dimension arranged as (cx, cy, w, h).
            - attn_mask (Tensor): The attention mask to prevent information
              leakage from different denoising groups and matching parts,
              will be used as `self_attn_mask` of the `decoder`, has shape
              (num_queries_total, num_queries_total), where
              `num_queries_total` is the sum of `num_denoising_queries` and
              `num_matching_queries`.
            - dn_meta (Dict[str, int]): The dictionary saves information about
              group collation, including 'num_denoising_queries' and
              'num_denoising_groups'. It will be used for splitting the
              outputs of the denoising and matching parts and for loss
              calculation.
        """
        # normalize bbox and collate ground truth (gt)
        gt_labels_list = []
        gt_bboxes_list = []
        for sample in batch_data_samples:
            img_h, img_w = sample.img_shape
            bboxes = sample.gt_instances.bboxes
            factor = bboxes.new_tensor([img_w, img_h, img_w,
                                        img_h]).unsqueeze(0)
            bboxes_normalized = bboxes / factor
            gt_bboxes_list.append(bboxes_normalized)
            gt_labels_list.append(sample.gt_instances.labels)
        gt_labels = torch.cat(gt_labels_list)  # (num_target_total, )
        gt_bboxes = torch.cat(gt_bboxes_list)  # (num_target_total, 4)

        num_target_list = [len(bboxes) for bboxes in gt_bboxes_list]
        max_num_target = max(num_target_list)
        num_groups = self.get_num_groups(max_num_target)

        dn_label_query = self.generate_dn_label_query(gt_labels, num_groups)
        dn_bbox_query = self.generate_dn_bbox_query(gt_bboxes, num_groups)

        # The `batch_idx` saves the batch index of the corresponding sample
        # for each target, has shape (num_target_total).
        batch_idx = torch.cat([
            torch.full_like(t.long(), i) for i, t in enumerate(gt_labels_list)
        ])
        dn_label_query, dn_bbox_query = self.collate_dn_queries(
            dn_label_query, dn_bbox_query, batch_idx, len(batch_data_samples),
            num_groups)

        attn_mask = self.generate_dn_mask(
            max_num_target, num_groups, device=dn_label_query.device)

        dn_meta = dict(
            num_denoising_queries=int(max_num_target * 2 * num_groups),
            num_denoising_groups=num_groups)

        return dn_label_query, dn_bbox_query, attn_mask, dn_meta
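
    # Editor's illustration (not part of the original file): expected output
    # shapes for a toy batch, assuming embed_dims=256, num_dn_queries=100,
    # and two samples with 3 and 5 gt boxes (so max_num_target = 5):
    #
    #     num_groups = 100 // 5 = 20
    #     num_denoising_queries = 5 * 2 * 20 = 200
    #     dn_label_query: (2, 200, 256)
    #     dn_bbox_query:  (2, 200, 4)
    #     attn_mask:      (200 + num_matching_queries,
    #                      200 + num_matching_queries)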

    def get_num_groups(self, max_num_target: int = None) -> int:
        """Calculate the number of denoising query groups.

        Two grouping strategies, 'static dn groups' and 'dynamic dn groups',
        are supported. When `self.dynamic_dn_groups` is `False`, the number
        of denoising query groups will always be `self.num_groups`. When
        `self.dynamic_dn_groups` is `True`, the group number will be dynamic,
        ensuring that the number of denoising queries does not exceed
        `self.num_dn_queries`, to prevent large fluctuations of memory.

        NOTE The `num_group` is shared across the samples in a batch. When
        the target numbers in the samples vary, the denoising queries of the
        samples containing fewer targets are padded to the max length.

        Args:
            max_num_target (int, optional): The max target number of the
                batch samples. It will only be used when
                `self.dynamic_dn_groups` is `True`. Defaults to `None`.

        Returns:
            int: The denoising group number of the current batch.
        """
        if self.dynamic_dn_groups:
            assert max_num_target is not None, \
                'group_queries should be provided when using ' \
                'dynamic dn groups'
            if max_num_target == 0:
                num_groups = 1
            else:
                num_groups = self.num_dn_queries // max_num_target
        else:
            num_groups = self.num_groups
        if num_groups < 1:
            num_groups = 1
        return int(num_groups)
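
    # Editor's illustration (not part of the original file): with dynamic
    # grouping and num_dn_queries=100, the group count adapts to the batch:
    #
    #     max_num_target = 30  ->  num_groups = 100 // 30 = 3
    #     max_num_target = 8   ->  num_groups = 100 // 8  = 12
    #     max_num_target = 150 ->  100 // 150 = 0, clamped to 1 group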

    def generate_dn_label_query(self, gt_labels: Tensor,
                                num_groups: int) -> Tensor:
        """Generate noisy labels and their query embeddings.

        The strategy for generating noisy labels is: randomly choose labels
        at a `self.label_noise_scale * 0.5` proportion of positions and
        override each of them with a random object category label.

        NOTE Noise is not added to all labels. Besides, the
        `self.label_noise_scale * 0.5` arg is the ratio of the chosen
        positions, which is higher than the actual proportion of noisy
        labels, because an overriding label may coincide with the correct
        one. The gap becomes larger as the number of target categories
        decreases. Users should be aware of this and modify the scale arg
        or the corresponding logic according to the specific dataset.

        Args:
            gt_labels (Tensor): The concatenated gt labels of all samples
                in the batch, has shape (num_target_total, ) where
                `num_target_total = sum(num_target_list)`.
            num_groups (int): The number of denoising query groups.

        Returns:
            Tensor: The query embeddings of noisy labels, has shape
            (num_noisy_targets, embed_dims), where `num_noisy_targets =
            num_target_total * num_groups * 2`.
        """
        assert self.label_noise_scale > 0
        gt_labels_expand = gt_labels.repeat(2 * num_groups,
                                            1).view(-1)  # Note `* 2`  # noqa
        p = torch.rand_like(gt_labels_expand.float())
        chosen_indice = torch.nonzero(p < (self.label_noise_scale * 0.5)).view(
            -1)  # Note `* 0.5`
        new_labels = torch.randint_like(chosen_indice, 0, self.num_classes)
        noisy_labels_expand = gt_labels_expand.scatter(0, chosen_indice,
                                                       new_labels)
        dn_label_query = self.label_embedding(noisy_labels_expand)
        return dn_label_query
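
    # Editor's illustration (not part of the original file): with the default
    # label_noise_scale=0.5, roughly 25% of the expanded label copies are
    # re-sampled uniformly from [0, num_classes). A standalone sketch of the
    # same scatter trick, assuming 80 classes:
    #
    #     import torch
    #     labels = torch.tensor([3, 7, 7, 0])        # expanded gt labels
    #     p = torch.rand_like(labels.float())
    #     chosen = torch.nonzero(p < 0.25).view(-1)  # ~25% of positions
    #     rand = torch.randint_like(chosen, 0, 80)   # random override labels
    #     noisy = labels.scatter(0, chosen, rand)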

    def generate_dn_bbox_query(self, gt_bboxes: Tensor,
                               num_groups: int) -> Tensor:
        """Generate noisy bboxes and their query embeddings.

        The strategy for generating noisy bboxes is as follows:

        .. code:: text

            +--------------------+
            |      negative      |
            |    +----------+    |
            |    | positive |    |
            |    |    +-----|----+------------+
            |    |    |     |    |            |
            |    +----+-----+    |            |
            |         |          |            |
            +---------+----------+            |
                      |                       |
                      |        gt bbox        |
                      |                       |
                      |            +---------+----------+
                      |            |         |          |
                      |            |    +----+-----+    |
                      |            |    |    |     |    |
                      +------------|----+----+     |    |
                                   |    | positive |    |
                                   |    +----------+    |
                                   |      negative      |
                                   +--------------------+

        The random noise is added to the top-left and bottom-right point
        coordinates, hence the normalized (x, y, x, y) format of bboxes is
        required. The noisy bboxes of positive queries have both points
        within the inner square, while those of negative queries have both
        points between the inner and outer squares.

        Besides, the side of the outer square is twice as long as that of
        the inner square, i.e., `self.box_noise_scale * w_or_h / 2`.

        NOTE The noise is added to all the bboxes. Moreover, there is still
        an unconsidered case where one point is within the positive square
        and the other is between the inner and outer squares.

        Args:
            gt_bboxes (Tensor): The concatenated gt bboxes of all samples
                in the batch, has shape (num_target_total, 4) with the last
                dimension arranged as normalized (x, y, x, y), where
                `num_target_total = sum(num_target_list)`.
            num_groups (int): The number of denoising query groups.

        Returns:
            Tensor: The output noisy bboxes, which are embedded by normalized
            (cx, cy, w, h) format bboxes going through inverse_sigmoid, has
            shape (num_noisy_targets, 4) with the last dimension arranged as
            (cx, cy, w, h), where
            `num_noisy_targets = num_target_total * num_groups * 2`.
        """
        assert self.box_noise_scale > 0
        device = gt_bboxes.device

        # expand gt_bboxes as groups
        gt_bboxes_expand = gt_bboxes.repeat(2 * num_groups, 1)  # xyxy

        # obtain index of negative queries in gt_bboxes_expand
        positive_idx = torch.arange(
            len(gt_bboxes), dtype=torch.long, device=device)
        positive_idx = positive_idx.unsqueeze(0).repeat(num_groups, 1)
        positive_idx += 2 * len(gt_bboxes) * torch.arange(
            num_groups, dtype=torch.long, device=device)[:, None]
        positive_idx = positive_idx.flatten()
        negative_idx = positive_idx + len(gt_bboxes)

        # determine the sign of each element in the random part of the added
        # noise to be positive or negative randomly.
        rand_sign = torch.randint_like(
            gt_bboxes_expand, low=0, high=2,
            dtype=torch.float32) * 2.0 - 1.0  # [low, high), 1 or -1, randomly

        # calculate the random part of the added noise
        rand_part = torch.rand_like(gt_bboxes_expand)  # [0, 1)
        rand_part[negative_idx] += 1.0  # pos: [0, 1); neg: [1, 2)
        rand_part *= rand_sign  # pos: (-1, 1); neg: (-2, -1] U [1, 2)

        # add noise to the bboxes
        bboxes_whwh = bbox_xyxy_to_cxcywh(gt_bboxes_expand)[:, 2:].repeat(1, 2)
        noisy_bboxes_expand = gt_bboxes_expand + torch.mul(
            rand_part, bboxes_whwh) * self.box_noise_scale / 2  # xyxy
        noisy_bboxes_expand = noisy_bboxes_expand.clamp(min=0.0, max=1.0)
        noisy_bboxes_expand = bbox_xyxy_to_cxcywh(noisy_bboxes_expand)

        dn_bbox_query = inverse_sigmoid(noisy_bboxes_expand, eps=1e-3)
        return dn_bbox_query
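
    # Editor's illustration (not part of the original file): for a gt box of
    # width w and height h with the default box_noise_scale=1.0, each x
    # coordinate is shifted by |shift| < w / 2 for positive queries and by
    # w / 2 <= |shift| < w for negative ones (likewise for y with h), which
    # matches the inner/outer squares in the diagram above. E.g. for w = 0.2:
    #
    #     positive shifts fall in (-0.1, 0.1)
    #     negative shifts fall in (-0.2, -0.1] U [0.1, 0.2)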

    def collate_dn_queries(self, input_label_query: Tensor,
                           input_bbox_query: Tensor, batch_idx: Tensor,
                           batch_size: int, num_groups: int) -> Tuple[Tensor]:
        """Collate generated queries to obtain batched dn queries.

        The strategy for query collation is as follows:

        .. code:: text

                    input_queries (num_target_total, query_dim)
            P_A1 P_B1 P_B2 N_A1 N_B1 N_B2 P'A1 P'B1 P'B2 N'A1 N'B1 N'B2
              |________ group1 ________|    |________ group2 ________|
                                         |
                                         V
                      P_A1 Pad0 N_A1 Pad0 P'A1 Pad0 N'A1 Pad0
                      P_B1 P_B2 N_B1 N_B2 P'B1 P'B2 N'B1 N'B2
                       |____ group1 ____|    |____ group2 ____|
              batched_queries (batch_size, num_denoising_queries, query_dim)

        where query_dim is 4 for bbox and self.embed_dims for label.
        Notation: _-group 1; '-group 2;
        A-Sample1(has 1 target); B-sample2(has 2 targets)

        Args:
            input_label_query (Tensor): The generated label queries of all
                targets, has shape (num_target_total, embed_dims) where
                `num_target_total = sum(num_target_list)`.
            input_bbox_query (Tensor): The generated bbox queries of all
                targets, has shape (num_target_total, 4) with the last
                dimension arranged as (cx, cy, w, h).
            batch_idx (Tensor): The batch index of the corresponding sample
                for each target, has shape (num_target_total).
            batch_size (int): The size of the input batch.
            num_groups (int): The number of denoising query groups.

        Returns:
            tuple[Tensor]: Output batched label and bbox queries.

            - batched_label_query (Tensor): The output batched label queries,
              has shape (batch_size, num_denoising_queries, embed_dims),
              where `num_denoising_queries = max_num_target * num_groups * 2`.
            - batched_bbox_query (Tensor): The output batched bbox queries,
              has shape (batch_size, num_denoising_queries, 4) with the last
              dimension arranged as (cx, cy, w, h).
        """
        device = input_label_query.device
        num_target_list = [
            torch.sum(batch_idx == idx) for idx in range(batch_size)
        ]
        max_num_target = max(num_target_list)
        num_denoising_queries = int(max_num_target * 2 * num_groups)

        map_query_index = torch.cat([
            torch.arange(num_target, device=device)
            for num_target in num_target_list
        ])
        map_query_index = torch.cat([
            map_query_index + max_num_target * i for i in range(2 * num_groups)
        ]).long()
        batch_idx_expand = batch_idx.repeat(2 * num_groups, 1).view(-1)
        mapper = (batch_idx_expand, map_query_index)

        batched_label_query = torch.zeros(
            batch_size, num_denoising_queries, self.embed_dims, device=device)
        batched_bbox_query = torch.zeros(
            batch_size, num_denoising_queries, 4, device=device)

        batched_label_query[mapper] = input_label_query
        batched_bbox_query[mapper] = input_bbox_query
        return batched_label_query, batched_bbox_query
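
    # Editor's illustration (not part of the original file): for the 2-sample
    # case in the diagram above (A has 1 target, B has 2, so max_num_target=2)
    # with 2 groups, the scatter indices work out to:
    #
    #     batch_idx        = [0, 1, 1]      # A, B, B
    #     batch_idx_expand = [0, 1, 1] * 4  # 2 groups x (pos, neg)
    #     map_query_index  = [0, 0, 1,  2, 2, 3,  4, 4, 5,  6, 6, 7]
    #
    # so sample A fills slots 0, 2, 4, 6 of its row and sample B fills all
    # eight; unused slots stay zero-padded.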

    def generate_dn_mask(self, max_num_target: int, num_groups: int,
                         device: Union[torch.device, str]) -> Tensor:
        """Generate attention mask to prevent information leakage from
        different denoising groups and matching parts.

        .. code:: text

                            0 0 0 0 1 1 1 1 0 0 0 0 0
                            0 0 0 0 1 1 1 1 0 0 0 0 0
                            0 0 0 0 1 1 1 1 0 0 0 0 0
                            0 0 0 0 1 1 1 1 0 0 0 0 0
                            1 1 1 1 0 0 0 0 0 0 0 0 0
                            1 1 1 1 0 0 0 0 0 0 0 0 0
                            1 1 1 1 0 0 0 0 0 0 0 0 0
                            1 1 1 1 0 0 0 0 0 0 0 0 0
                            1 1 1 1 1 1 1 1 0 0 0 0 0
                            1 1 1 1 1 1 1 1 0 0 0 0 0
                            1 1 1 1 1 1 1 1 0 0 0 0 0
                            1 1 1 1 1 1 1 1 0 0 0 0 0
                            1 1 1 1 1 1 1 1 0 0 0 0 0
             max_num_target |_|           |_________| num_matching_queries
                            |_____________| num_denoising_queries

        1 -> True (Masked), means 'can not see'.
        0 -> False (UnMasked), means 'can see'.

        Args:
            max_num_target (int): The max target number of the input batch
                samples.
            num_groups (int): The number of denoising query groups.
            device (obj:`device` or str): The device of the generated mask.

        Returns:
            Tensor: The attention mask to prevent information leakage from
            different denoising groups and matching parts, will be used as
            `self_attn_mask` of the `decoder`, has shape (num_queries_total,
            num_queries_total), where `num_queries_total` is the sum of
            `num_denoising_queries` and `num_matching_queries`.
        """
        num_denoising_queries = int(max_num_target * 2 * num_groups)
        num_queries_total = num_denoising_queries + self.num_matching_queries
        attn_mask = torch.zeros(
            num_queries_total,
            num_queries_total,
            device=device,
            dtype=torch.bool)
        # Make the matching part unable to see the denoising groups
        attn_mask[num_denoising_queries:, :num_denoising_queries] = True
        # Make the denoising groups unable to see each other
        for i in range(num_groups):
            # Mask rows of one group per step.
            row_scope = slice(max_num_target * 2 * i,
                              max_num_target * 2 * (i + 1))
            left_scope = slice(max_num_target * 2 * i)
            right_scope = slice(max_num_target * 2 * (i + 1),
                                num_denoising_queries)
            attn_mask[row_scope, right_scope] = True
            attn_mask[row_scope, left_scope] = True
        return attn_mask
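
    # Editor's illustration (not part of the original file): the 13 x 13
    # diagram in the docstring corresponds to max_num_target=2, num_groups=2
    # and num_matching_queries=5. A quick standalone check, using the
    # hypothetical `dn_generator` from the earlier sketch but built with
    # num_matching_queries=5:
    #
    #     mask = dn_generator.generate_dn_mask(2, 2, device='cpu')
    #     assert mask.shape == (13, 13)   # 2 * 2 * 2 denoising + 5 matching
    #     assert not mask[:4, :4].any()   # group 1 can see itself
    #     assert mask[:4, 4:8].all()      # but not group 2
    #     assert not mask[8:, 8:].any()   # matching part sees itself
    #     assert mask[8:, :8].all()       # but not the denoising part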