# dino_utils.py

import torch
import torchvision.ops._box_convert as box_op
from mmdet.models.layers.transformer.utils import inverse_sigmoid


def get_contrastive_denoising_training_group(targets,
                                             num_classes,
                                             num_queries,
                                             class_embed,
                                             num_denoising=100,
                                             label_noise_ratio=0.5,
                                             box_noise_scale=1.0):
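    """Build a contrastive denoising (CDN) query group for DINO/RT-DETR-style training.

    Args:
        targets (list[dict]): per-image ground truth with integer 'labels' and
            'boxes' in normalized cxcywh format.
        num_classes (int): number of object classes; index ``num_classes`` is used
            as the padding label, so ``class_embed`` should cover ``num_classes + 1``
            entries (e.g. an ``nn.Embedding`` with ``num_classes + 1`` rows).
        num_queries (int): number of regular matching queries appended after the
            denoising queries.
        class_embed (callable): label embedding module applied to the (noised) class indices.
        num_denoising (int): budget of denoising queries used to derive the number of groups.
        label_noise_ratio (float): controls the fraction of labels flipped to a random class.
        box_noise_scale (float): scale of the box jitter.

    Returns:
        tuple: ``(input_query_class, input_query_bbox, attn_mask, dn_meta)`` where the
        class queries are already embedded, the boxes are in inverse-sigmoid space,
        and ``dn_meta['dn_num_split']`` gives the [denoising, matching] split of the
        query sequence; ``(None, None, None, None)`` if there is nothing to denoise.
    """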
    if num_denoising <= 0:
        return None, None, None, None
    num_gts = [len(t['labels']) for t in targets]
    device = targets[0]['labels'].device
    max_gt_num = max(num_gts)
    if max_gt_num == 0:
        return None, None, None, None
    num_group = num_denoising // max_gt_num
    num_group = 1 if num_group == 0 else num_group
    # pad gt to max_num of a batch
    bs = len(num_gts)
    input_query_class = torch.full([bs, max_gt_num], num_classes, dtype=torch.int32, device=device)
    input_query_bbox = torch.zeros([bs, max_gt_num, 4], device=device)
    pad_gt_mask = torch.zeros([bs, max_gt_num], dtype=torch.bool, device=device)
    for i in range(bs):
        num_gt = num_gts[i]
        if num_gt > 0:
            input_query_class[i, :num_gt] = targets[i]['labels']
            input_query_bbox[i, :num_gt] = targets[i]['boxes']
            pad_gt_mask[i, :num_gt] = 1
    # each group has positive and negative queries.
    input_query_class = input_query_class.tile([1, 2 * num_group])
    input_query_bbox = input_query_bbox.tile([1, 2 * num_group, 1])
    pad_gt_mask = pad_gt_mask.tile([1, 2 * num_group])
    # positive and negative mask
    negative_gt_mask = torch.zeros([bs, max_gt_num * 2, 1], device=device)
    negative_gt_mask[:, max_gt_num:] = 1
    negative_gt_mask = negative_gt_mask.tile([1, num_group, 1])
    positive_gt_mask = 1 - negative_gt_mask
    # contrastive denoising training positive index
    positive_gt_mask = positive_gt_mask.squeeze(-1) * pad_gt_mask
    dn_positive_idx = torch.nonzero(positive_gt_mask)[:, 1]
    dn_positive_idx = torch.split(dn_positive_idx, [n * num_group for n in num_gts])
    # total denoising queries
    num_denoising = int(max_gt_num * 2 * num_group)
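    # num_denoising now counts the full padded layout: num_group blocks per image,
    # each block holding max_gt_num positive queries followed by max_gt_num negative ones.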
    if label_noise_ratio > 0:
        # flip roughly (label_noise_ratio * 0.5) of the non-padded labels
        mask = torch.rand_like(input_query_class, dtype=torch.float) < (label_noise_ratio * 0.5)
        # randomly put a new one here
        new_label = torch.randint_like(mask, 0, num_classes, dtype=input_query_class.dtype)
        input_query_class = torch.where(mask & pad_gt_mask, new_label, input_query_class)
    # commented-out flattened variant of the same label-noise step:
    # if label_noise_ratio > 0:
    #     input_query_class = input_query_class.flatten()
    #     pad_gt_mask = pad_gt_mask.flatten()
    #     # half of bbox prob
    #     # mask = torch.rand(input_query_class.shape, device=device) < (label_noise_ratio * 0.5)
    #     mask = torch.rand_like(input_query_class) < (label_noise_ratio * 0.5)
    #     chosen_idx = torch.nonzero(mask * pad_gt_mask).squeeze(-1)
    #     # randomly put a new one here
    #     new_label = torch.randint_like(chosen_idx, 0, num_classes, dtype=input_query_class.dtype)
    #     # input_query_class.scatter_(dim=0, index=chosen_idx, value=new_label)
    #     input_query_class[chosen_idx] = new_label
    #     input_query_class = input_query_class.reshape(bs, num_denoising)
    #     pad_gt_mask = pad_gt_mask.reshape(bs, num_denoising)
    if box_noise_scale > 0:
        known_bbox = box_op._box_cxcywh_to_xyxy(input_query_bbox)
        # per-coordinate jitter bound: half the box width/height times box_noise_scale
        diff = torch.tile(input_query_bbox[..., 2:] * 0.5, [1, 1, 2]) * box_noise_scale
        rand_sign = torch.randint_like(input_query_bbox, 0, 2) * 2.0 - 1.0
        rand_part = torch.rand_like(input_query_bbox)
        # negative queries draw noise in [1, 2) so they land farther from the gt than positives ([0, 1))
        rand_part = (rand_part + 1.0) * negative_gt_mask + rand_part * (1 - negative_gt_mask)
        rand_part *= rand_sign
        known_bbox += rand_part * diff
        known_bbox.clip_(min=0.0, max=1.0)
        input_query_bbox = box_op._box_xyxy_to_cxcywh(known_bbox)
    input_query_bbox = inverse_sigmoid(input_query_bbox)
    # class_embed = torch.concat([class_embed, torch.zeros([1, class_embed.shape[-1]], device=device)])
    # input_query_class = torch.gather(
    #     class_embed, input_query_class.flatten(),
    #     axis=0).reshape(bs, num_denoising, -1)
    # input_query_class = class_embed(input_query_class.flatten()).reshape(bs, num_denoising, -1)
    input_query_class = class_embed(input_query_class)
    tgt_size = num_denoising + num_queries
    # attn_mask = torch.ones([tgt_size, tgt_size], device=device) < 0
    # True entries mark positions a query is NOT allowed to attend to
    attn_mask = torch.full([tgt_size, tgt_size], False, dtype=torch.bool, device=device)
    # match query cannot see the reconstruction
    attn_mask[num_denoising:, :num_denoising] = True
    # reconstruct cannot see each other
    for i in range(num_group):
        if i == 0:
            attn_mask[max_gt_num * 2 * i: max_gt_num * 2 * (i + 1), max_gt_num * 2 * (i + 1): num_denoising] = True
        if i == num_group - 1:
            attn_mask[max_gt_num * 2 * i: max_gt_num * 2 * (i + 1), :max_gt_num * i * 2] = True
        else:
            attn_mask[max_gt_num * 2 * i: max_gt_num * 2 * (i + 1), max_gt_num * 2 * (i + 1): num_denoising] = True
            attn_mask[max_gt_num * 2 * i: max_gt_num * 2 * (i + 1), :max_gt_num * 2 * i] = True
    dn_meta = {
        "dn_positive_idx": dn_positive_idx,
        "dn_num_group": num_group,
        "dn_num_split": [num_denoising, num_queries]
    }
    # print(input_query_class.shape)  # torch.Size([4, 196, 256])
    # print(input_query_bbox.shape)   # torch.Size([4, 196, 4])
    # print(attn_mask.shape)          # torch.Size([496, 496])
    return input_query_class, input_query_bbox, attn_mask, dn_meta
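

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original training pipeline. It assumes
    # 80 classes, 256-d embeddings, 300 matching queries, and a plain nn.Embedding standing
    # in for the model's label embedding; the targets below are made-up examples.
    import torch.nn as nn

    num_classes, hidden_dim, num_queries = 80, 256, 300
    targets = [
        {'labels': torch.tensor([3, 17]),
         'boxes': torch.tensor([[0.5, 0.5, 0.2, 0.3],
                                [0.3, 0.6, 0.1, 0.1]])},
        {'labels': torch.tensor([9]),
         'boxes': torch.tensor([[0.7, 0.4, 0.25, 0.15]])},
    ]
    # the padding label num_classes needs its own row, hence num_classes + 1
    class_embed = nn.Embedding(num_classes + 1, hidden_dim)

    dn_cls, dn_bbox, attn_mask, dn_meta = get_contrastive_denoising_training_group(
        targets, num_classes, num_queries, class_embed)
    print(dn_cls.shape, dn_bbox.shape, attn_mask.shape, dn_meta['dn_num_split'])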