# rtdetr_head.py

from typing import Dict, List, Tuple

import torch
from torch import Tensor

from mmdet.registry import MODELS
from mmdet.structures import SampleList
from mmdet.structures.bbox import bbox_cxcywh_to_xyxy, bbox_overlaps
from mmdet.utils import InstanceList, reduce_mean
from ..losses import VarifocalLoss
from .dino_head import DINOHead

@MODELS.register_module()
class RTDETRHead(DINOHead):
    r"""Head of RT-DETR: DETRs Beat YOLOs on Real-time Object Detection.

    The overall loss framework is implemented in :class:`DINOHead`;
    ``RTDETRHead`` provides its own ``loss_by_feat_single`` and
    ``_loss_dn_single`` implementations.
    """

    def loss(self, hidden_states: Tensor, references: Tensor,
             enc_outputs_class: Tensor, enc_outputs_coord: Tensor,
             batch_data_samples: SampleList,
             dn_meta: Dict[str, int]) -> dict:
        """Collect ground-truth instances and image metas from the data
        samples, then delegate the loss computation to ``loss_by_feat``."""
        batch_gt_instances = []
        batch_img_metas = []
        for data_sample in batch_data_samples:
            batch_img_metas.append(data_sample.metainfo)
            batch_gt_instances.append(data_sample.gt_instances)

        loss_inputs = (hidden_states, references, enc_outputs_class,
                       enc_outputs_coord, batch_gt_instances,
                       batch_img_metas, dn_meta)
        losses = self.loss_by_feat(*loss_inputs)
        return losses
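
    # NOTE: `dn_meta` comes from the denoising query generator and is only
    # consumed by `loss_by_feat` / `get_dn_targets`. In the mmdet DINO
    # implementation it is typically a dict like
    # dict(num_denoising_queries=..., num_denoising_groups=...); the exact
    # keys are assumed here from DINOHead, not defined in this file.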

    def forward(self, hidden_states: Tensor,
                references: Tensor) -> Tuple[Tensor]:
        """Forward function.

        Args:
            hidden_states (Tensor): Hidden states output from each decoder
                layer, has shape (num_decoder_layers, bs, num_queries, dim).
            references (list[Tensor]): List of the references from the
                decoder. The first reference is the `init_reference`
                (initial) and the other num_decoder_layers(6) references
                are `inter_references` (intermediate). The `init_reference`
                has shape (bs, num_queries, 4) when `as_two_stage` of the
                detector is `True`, otherwise (bs, num_queries, 2). Each
                `inter_reference` has shape (bs, num_queries, 4) when
                `with_box_refine` of the detector is `True`, otherwise
                (bs, num_queries, 2). The coordinates are arranged as
                (cx, cy) when the last dimension is 2, and (cx, cy, w, h)
                when it is 4.

        Returns:
            tuple[Tensor]: ``hidden_states`` and ``references`` are returned
            unchanged; the per-layer classification and regression branches
            (kept below as commented-out code) are not applied here.
        """
        # all_layers_outputs_classes = []
        # all_layers_outputs_coords = []
        # for layer_id in range(hidden_states.shape[0]):
        #     reference = inverse_sigmoid(references[layer_id])
        #     # NOTE The last reference will not be used.
        #     hidden_state = hidden_states[layer_id]
        #     outputs_class = self.cls_branches[layer_id](hidden_state)
        #     tmp_reg_preds = self.reg_branches[layer_id](hidden_state)
        #     if reference.shape[-1] == 4:
        #         # When `layer` is 0 and `as_two_stage` of the detector
        #         # is `True`, or when `layer` is greater than 0 and
        #         # `with_box_refine` of the detector is `True`.
        #         tmp_reg_preds += reference
        #     else:
        #         # When `layer` is 0 and `as_two_stage` of the detector
        #         # is `False`, or when `layer` is greater than 0 and
        #         # `with_box_refine` of the detector is `False`.
        #         assert reference.shape[-1] == 2
        #         tmp_reg_preds[..., :2] += reference
        #     outputs_coord = tmp_reg_preds.sigmoid()
        #     all_layers_outputs_classes.append(outputs_class)
        #     all_layers_outputs_coords.append(outputs_coord)
        # all_layers_outputs_classes = torch.stack(all_layers_outputs_classes)
        # all_layers_outputs_coords = torch.stack(all_layers_outputs_coords)
        return hidden_states, references
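
    # Example (illustrative only, following the docstring shapes above): with
    # num_decoder_layers=6, bs=2, num_queries=300 and dim=256,
    # `hidden_states` would be (6, 2, 300, 256) and each reference
    # (2, 300, 4); both inputs are returned as-is.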

    def loss_by_feat_single(self, cls_scores: Tensor, bbox_preds: Tensor,
                            batch_gt_instances: InstanceList,
                            batch_img_metas: List[dict]) -> Tuple[Tensor]:
        """Loss of a single decoder layer for the matching (non-denoising)
        queries."""
        num_imgs = cls_scores.size(0)
        cls_scores_list = [cls_scores[i] for i in range(num_imgs)]
        bbox_preds_list = [bbox_preds[i] for i in range(num_imgs)]
        cls_reg_targets = self.get_targets(cls_scores_list, bbox_preds_list,
                                           batch_gt_instances,
                                           batch_img_metas)
        (labels_list, label_weights_list, bbox_targets_list,
         bbox_weights_list, num_total_pos, num_total_neg) = cls_reg_targets
        labels = torch.cat(labels_list, 0)
        label_weights = torch.cat(label_weights_list, 0)
        bbox_targets = torch.cat(bbox_targets_list, 0)
        bbox_weights = torch.cat(bbox_weights_list, 0)

        # classification loss
        cls_scores = cls_scores.reshape(-1, self.cls_out_channels)
        # construct weighted avg_factor to match with the official DETR repo
        cls_avg_factor = num_total_pos * 1.0 + \
            num_total_neg * self.bg_cls_weight
        if self.sync_cls_avg_factor:
            cls_avg_factor = reduce_mean(
                cls_scores.new_tensor([cls_avg_factor]))
        cls_avg_factor = max(cls_avg_factor, 1)

        if isinstance(self.loss_cls, VarifocalLoss):
            # IoU-aware classification target: each positive query is
            # supervised with the IoU between its decoded box and the
            # assigned ground-truth box.
            bg_class_ind = self.num_classes
            pos_inds = ((labels >= 0)
                        & (labels < bg_class_ind)).nonzero().squeeze(1)
            cls_iou_targets = label_weights.new_zeros(cls_scores.shape)
            pos_bbox_targets = bbox_targets[pos_inds]
            pos_decode_bbox_targets = bbox_cxcywh_to_xyxy(pos_bbox_targets)
            pos_bbox_pred = bbox_preds.reshape(-1, 4)[pos_inds]
            pos_decode_bbox_pred = bbox_cxcywh_to_xyxy(pos_bbox_pred)
            pos_labels = labels[pos_inds]
            cls_iou_targets[pos_inds, pos_labels] = bbox_overlaps(
                pos_decode_bbox_pred.detach(),
                pos_decode_bbox_targets,
                is_aligned=True)
            loss_cls = self.loss_cls(
                cls_scores, cls_iou_targets, avg_factor=cls_avg_factor)
        else:
            loss_cls = self.loss_cls(
                cls_scores, labels, label_weights, avg_factor=cls_avg_factor)

        # Compute the average number of gt boxes across all gpus, for
        # normalization purposes
        num_total_pos = loss_cls.new_tensor([num_total_pos])
        num_total_pos = torch.clamp(reduce_mean(num_total_pos), min=1).item()

        # construct factors used for rescaling bboxes
        factors = []
        for img_meta, bbox_pred in zip(batch_img_metas, bbox_preds):
            img_h, img_w = img_meta['img_shape']
            factor = bbox_pred.new_tensor([img_w, img_h, img_w,
                                           img_h]).unsqueeze(0).repeat(
                                               bbox_pred.size(0), 1)
            factors.append(factor)
        factors = torch.cat(factors, 0)

        # DETR regresses the relative position of boxes (cxcywh) in the
        # image, so the learning target is normalized by the image size.
        # Re-scale the boxes here to compute the IoU loss.
        bbox_preds = bbox_preds.reshape(-1, 4)
        bboxes = bbox_cxcywh_to_xyxy(bbox_preds) * factors
        bboxes_gt = bbox_cxcywh_to_xyxy(bbox_targets) * factors

        # regression IoU loss, GIoU loss by default
        loss_iou = self.loss_iou(
            bboxes, bboxes_gt, bbox_weights, avg_factor=num_total_pos)
        # regression L1 loss
        loss_bbox = self.loss_bbox(
            bbox_preds, bbox_targets, bbox_weights, avg_factor=num_total_pos)
        return loss_cls, loss_bbox, loss_iou
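
    # Toy example of the IoU-aware target used above (illustrative numbers,
    # not taken from the code): if a positive query is assigned class 2 and
    # its decoded box overlaps its ground-truth box with IoU 0.7, then
    # cls_iou_targets[query, 2] = 0.7 while every other entry of that row
    # stays 0, so VarifocalLoss pulls the class score towards the IoU.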

    def _loss_dn_single(self, dn_cls_scores: Tensor, dn_bbox_preds: Tensor,
                        batch_gt_instances: InstanceList,
                        batch_img_metas: List[Dict],
                        dn_meta: Dict[str, int]) -> Tuple[Tensor]:
        """Denoising loss of a single decoder layer. Same as
        ``loss_by_feat_single`` except that denoising targets are used."""
        cls_reg_targets = self.get_dn_targets(batch_gt_instances,
                                              batch_img_metas, dn_meta)
        (labels_list, label_weights_list, bbox_targets_list,
         bbox_weights_list, num_total_pos, num_total_neg) = cls_reg_targets
        labels = torch.cat(labels_list, 0)
        label_weights = torch.cat(label_weights_list, 0)
        bbox_targets = torch.cat(bbox_targets_list, 0)
        bbox_weights = torch.cat(bbox_weights_list, 0)

        # reshape class scores to (-1, cls_out_channels), i.e. one row per
        # denoising query
        cls_scores = dn_cls_scores.reshape(-1, self.cls_out_channels)
        # construct weighted avg_factor to match with the official DETR repo
        cls_avg_factor = \
            num_total_pos * 1.0 + num_total_neg * self.bg_cls_weight
        if self.sync_cls_avg_factor:
            cls_avg_factor = reduce_mean(
                cls_scores.new_tensor([cls_avg_factor]))
        cls_avg_factor = max(cls_avg_factor, 1)

        if len(cls_scores) > 0:
            # The varifocal loss focuses more on positive targets, which can
            # make the model pay more attention to rarely seen objects. For a
            # single-class task such as barcode detection it may not be
            # needed.
            if isinstance(self.loss_cls, VarifocalLoss):
                bg_class_ind = self.num_classes
                pos_inds = ((labels >= 0)
                            & (labels < bg_class_ind)).nonzero().squeeze(1)
                cls_iou_targets = label_weights.new_zeros(cls_scores.shape)
                pos_bbox_targets = bbox_targets[pos_inds]
                pos_decode_bbox_targets = bbox_cxcywh_to_xyxy(
                    pos_bbox_targets)
                pos_bbox_pred = dn_bbox_preds.reshape(-1, 4)[pos_inds]
                pos_decode_bbox_pred = bbox_cxcywh_to_xyxy(pos_bbox_pred)
                pos_labels = labels[pos_inds]
                cls_iou_targets[pos_inds, pos_labels] = bbox_overlaps(
                    pos_decode_bbox_pred.detach(),
                    pos_decode_bbox_targets,
                    is_aligned=True)
                loss_cls = self.loss_cls(
                    cls_scores, cls_iou_targets, avg_factor=cls_avg_factor)
            else:
                loss_cls = self.loss_cls(
                    cls_scores,
                    labels,
                    label_weights,
                    avg_factor=cls_avg_factor)
        else:
            loss_cls = torch.zeros(
                1, dtype=cls_scores.dtype, device=cls_scores.device)

        # Compute the average number of gt boxes across all gpus, for
        # normalization purposes
        num_total_pos = loss_cls.new_tensor([num_total_pos])
        num_total_pos = torch.clamp(reduce_mean(num_total_pos), min=1).item()

        # construct factors used for rescaling bboxes
        factors = []
        for img_meta, bbox_pred in zip(batch_img_metas, dn_bbox_preds):
            img_h, img_w = img_meta['img_shape']
            factor = bbox_pred.new_tensor([img_w, img_h, img_w,
                                           img_h]).unsqueeze(0).repeat(
                                               bbox_pred.size(0), 1)
            factors.append(factor)
        factors = torch.cat(factors)

        # DETR regresses the relative position of boxes (cxcywh) in the
        # image, so the learning target is normalized by the image size.
        # Re-scale the boxes here to compute the IoU loss.
        bbox_preds = dn_bbox_preds.reshape(-1, 4)
        bboxes = bbox_cxcywh_to_xyxy(bbox_preds) * factors
        bboxes_gt = bbox_cxcywh_to_xyxy(bbox_targets) * factors

        # regression IoU loss, GIoU loss by default
        loss_iou = self.loss_iou(
            bboxes, bboxes_gt, bbox_weights, avg_factor=num_total_pos)
        # regression L1 loss
        loss_bbox = self.loss_bbox(
            bbox_preds, bbox_targets, bbox_weights, avg_factor=num_total_pos)
        return loss_cls, loss_bbox, loss_iou
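

# A minimal config sketch for wiring this head up (field names are an
# assumption following the usual mmdet DETR-family head conventions, not
# taken from this file; adjust to the actual RT-DETR detector config):
#
#     bbox_head=dict(
#         type='RTDETRHead',
#         num_classes=80,
#         sync_cls_avg_factor=True,
#         loss_cls=dict(
#             type='VarifocalLoss', use_sigmoid=True, loss_weight=1.0),
#         loss_bbox=dict(type='L1Loss', loss_weight=5.0),
#         loss_iou=dict(type='GIoULoss', loss_weight=2.0))
#
# With `loss_cls` set to VarifocalLoss, the IoU-aware classification target
# branch above is taken; any other classification loss falls back to the
# plain label/weight formulation inherited from DINOHead.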