from typing import Dict, List, Tuple

import torch
from torch import Tensor

from mmdet.registry import MODELS
from mmdet.structures import SampleList
from mmdet.structures.bbox import bbox_cxcywh_to_xyxy, bbox_overlaps
from mmdet.utils import InstanceList, reduce_mean
from ..losses import VarifocalLoss
from .dino_head import DINOHead


@MODELS.register_module()
class RTDETRHead(DINOHead):
    r"""Head of RT-DETR (`DETRs Beat YOLOs on Real-time Object Detection`).

    The overall loss computation is inherited from :class:`DINOHead`; this
    head only overrides :meth:`loss_by_feat_single` and
    :meth:`_loss_dn_single`, e.g. to build IoU-aware classification targets
    when :class:`VarifocalLoss` is used.
    """

    def loss(self, hidden_states: Tensor, references: Tensor,
             enc_outputs_class: Tensor, enc_outputs_coord: Tensor,
             batch_data_samples: SampleList,
             dn_meta: Dict[str, int]) -> dict:
        """Collect ground-truth instances and image metas from
        ``batch_data_samples`` and delegate the loss computation to
        :meth:`loss_by_feat`."""
        batch_gt_instances = []
        batch_img_metas = []
        for data_sample in batch_data_samples:
            batch_img_metas.append(data_sample.metainfo)
            batch_gt_instances.append(data_sample.gt_instances)

        loss_inputs = (hidden_states, references, enc_outputs_class,
                       enc_outputs_coord, batch_gt_instances,
                       batch_img_metas, dn_meta)
        losses = self.loss_by_feat(*loss_inputs)
        return losses

    def forward(self, hidden_states: Tensor,
                references: Tensor) -> Tuple[Tensor]:
        """Forward function.

        Args:
            hidden_states (Tensor): Hidden states output from each decoder
                layer, has shape (num_decoder_layers, bs, num_queries, dim).
            references (list[Tensor]): List of the references from the
                decoder. The first reference is the `init_reference`
                (initial) and the other num_decoder_layers(6) references are
                `inter_references` (intermediate). The `init_reference` has
                shape (bs, num_queries, 4) when `as_two_stage` of the
                detector is `True`, otherwise (bs, num_queries, 2). Each
                `inter_reference` has shape (bs, num_queries, 4) when
                `with_box_refine` of the detector is `True`, otherwise
                (bs, num_queries, 2). The coordinates are arranged as
                (cx, cy) when the last dimension is 2, and (cx, cy, w, h)
                when it is 4.

        Returns:
            tuple[Tensor]: ``hidden_states`` and ``references``, returned
            unchanged. Unlike :class:`DINOHead`, the per-layer classification
            and regression branches are not applied here; the DINOHead-style
            implementation is kept below, commented out, for reference.
        """
        # all_layers_outputs_classes = []
        # all_layers_outputs_coords = []

        # for layer_id in range(hidden_states.shape[0]):
        #     reference = inverse_sigmoid(references[layer_id])
        #     # NOTE The last reference will not be used.
        #     hidden_state = hidden_states[layer_id]
        #     outputs_class = self.cls_branches[layer_id](hidden_state)
        #     tmp_reg_preds = self.reg_branches[layer_id](hidden_state)
        #     if reference.shape[-1] == 4:
        #         # When `layer` is 0 and `as_two_stage` of the detector
        #         # is `True`, or when `layer` is greater than 0 and
        #         # `with_box_refine` of the detector is `True`.
        #         tmp_reg_preds += reference
        #     else:
        #         # When `layer` is 0 and `as_two_stage` of the detector
        #         # is `False`, or when `layer` is greater than 0 and
        #         # `with_box_refine` of the detector is `False`.
        #         assert reference.shape[-1] == 2
        #         tmp_reg_preds[..., :2] += reference
        #     outputs_coord = tmp_reg_preds.sigmoid()
        #     all_layers_outputs_classes.append(outputs_class)
        #     all_layers_outputs_coords.append(outputs_coord)

        # all_layers_outputs_classes = torch.stack(all_layers_outputs_classes)
        # all_layers_outputs_coords = torch.stack(all_layers_outputs_coords)

        return hidden_states, references
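    # NOTE: illustrative shapes (assuming 6 decoder layers, batch size 2 and
    # 300 queries): `hidden_states` -> (6, 2, 300, dim), each reference in
    # `references` -> (2, 300, 4) in normalized (cx, cy, w, h). Both are
    # forwarded untouched to `self.loss_by_feat` by `loss` above.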
    # without denoising: losses for the matched (Hungarian-assigned) queries
    def loss_by_feat_single(self, cls_scores: Tensor, bbox_preds: Tensor,
                            batch_gt_instances: InstanceList,
                            batch_img_metas: List[dict]) -> Tuple[Tensor]:
        """Compute the classification, L1 and IoU losses for the outputs of
        a single decoder layer."""
        num_imgs = cls_scores.size(0)
        cls_scores_list = [cls_scores[i] for i in range(num_imgs)]
        bbox_preds_list = [bbox_preds[i] for i in range(num_imgs)]
        cls_reg_targets = self.get_targets(cls_scores_list, bbox_preds_list,
                                           batch_gt_instances,
                                           batch_img_metas)
        (labels_list, label_weights_list, bbox_targets_list,
         bbox_weights_list, num_total_pos, num_total_neg) = cls_reg_targets
        labels = torch.cat(labels_list, 0)
        label_weights = torch.cat(label_weights_list, 0)
        bbox_targets = torch.cat(bbox_targets_list, 0)
        bbox_weights = torch.cat(bbox_weights_list, 0)

        # classification loss
        cls_scores = cls_scores.reshape(-1, self.cls_out_channels)
        # construct weighted avg_factor to match the official DETR repo
        cls_avg_factor = num_total_pos * 1.0 + \
            num_total_neg * self.bg_cls_weight
        if self.sync_cls_avg_factor:
            cls_avg_factor = reduce_mean(
                cls_scores.new_tensor([cls_avg_factor]))
        cls_avg_factor = max(cls_avg_factor, 1)

        if isinstance(self.loss_cls, VarifocalLoss):
            # IoU-aware classification target: the target of a positive query
            # is the IoU between its decoded box and the assigned gt box,
            # placed at the index of the assigned class.
            bg_class_ind = self.num_classes
            pos_inds = ((labels >= 0)
                        & (labels < bg_class_ind)).nonzero().squeeze(1)
            cls_iou_targets = label_weights.new_zeros(cls_scores.shape)
            pos_bbox_targets = bbox_targets[pos_inds]
            pos_decode_bbox_targets = bbox_cxcywh_to_xyxy(pos_bbox_targets)
            pos_bbox_pred = bbox_preds.reshape(-1, 4)[pos_inds]
            pos_decode_bbox_pred = bbox_cxcywh_to_xyxy(pos_bbox_pred)
            pos_labels = labels[pos_inds]
            cls_iou_targets[pos_inds, pos_labels] = bbox_overlaps(
                pos_decode_bbox_pred.detach(),
                pos_decode_bbox_targets,
                is_aligned=True)
            loss_cls = self.loss_cls(
                cls_scores, cls_iou_targets, avg_factor=cls_avg_factor)
        else:
            loss_cls = self.loss_cls(
                cls_scores, labels, label_weights, avg_factor=cls_avg_factor)

        # Compute the average number of gt boxes across all gpus, for
        # normalization purposes
        num_total_pos = loss_cls.new_tensor([num_total_pos])
        num_total_pos = torch.clamp(reduce_mean(num_total_pos), min=1).item()

        # construct factors used to rescale bboxes
        factors = []
        for img_meta, bbox_pred in zip(batch_img_metas, bbox_preds):
            img_h, img_w = img_meta['img_shape']
            factor = bbox_pred.new_tensor([img_w, img_h, img_w,
                                           img_h]).unsqueeze(0).repeat(
                                               bbox_pred.size(0), 1)
            factors.append(factor)
        factors = torch.cat(factors, 0)

        # DETR regresses normalized (cx, cy, w, h) boxes, so the learning
        # targets are normalized by the image size; rescale both predictions
        # and targets to absolute coordinates before computing the IoU loss.
        bbox_preds = bbox_preds.reshape(-1, 4)
        bboxes = bbox_cxcywh_to_xyxy(bbox_preds) * factors
        bboxes_gt = bbox_cxcywh_to_xyxy(bbox_targets) * factors

        # regression IoU loss, GIoU loss by default
        loss_iou = self.loss_iou(
            bboxes, bboxes_gt, bbox_weights, avg_factor=num_total_pos)

        # regression L1 loss (in normalized cxcywh space)
        loss_bbox = self.loss_bbox(
            bbox_preds, bbox_targets, bbox_weights, avg_factor=num_total_pos)
        return loss_cls, loss_bbox, loss_iou
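    # NOTE: a worked example of the rescaling above (illustrative values): a
    # normalized prediction (cx, cy, w, h) = (0.5, 0.5, 0.2, 0.4) on a
    # 640x640 image converts to (0.4, 0.3, 0.6, 0.7) in normalized xyxy and
    # to (256, 192, 384, 448) in pixels after multiplying by
    # [img_w, img_h, img_w, img_h]; the (G)IoU loss uses these absolute
    # boxes, while the L1 loss stays in normalized cxcywh space.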
    # with denoising: same as loss_by_feat_single except that the targets
    # come from the denoising queries (get_dn_targets)
    def _loss_dn_single(self, dn_cls_scores: Tensor, dn_bbox_preds: Tensor,
                        batch_gt_instances: InstanceList,
                        batch_img_metas: List[Dict],
                        dn_meta: Dict[str, int]) -> Tuple[Tensor]:
        """Compute the denoising losses for the outputs of a single decoder
        layer."""
        cls_reg_targets = self.get_dn_targets(batch_gt_instances,
                                              batch_img_metas, dn_meta)
        (labels_list, label_weights_list, bbox_targets_list,
         bbox_weights_list, num_total_pos, num_total_neg) = cls_reg_targets
        labels = torch.cat(labels_list, 0)
        label_weights = torch.cat(label_weights_list, 0)
        bbox_targets = torch.cat(bbox_targets_list, 0)
        bbox_weights = torch.cat(bbox_weights_list, 0)

        # flatten the class scores so each row corresponds to one denoising
        # query, shape (-1, cls_out_channels)
        cls_scores = dn_cls_scores.reshape(-1, self.cls_out_channels)
        # construct weighted avg_factor to match the official DETR repo
        cls_avg_factor = \
            num_total_pos * 1.0 + num_total_neg * self.bg_cls_weight
        if self.sync_cls_avg_factor:
            cls_avg_factor = reduce_mean(
                cls_scores.new_tensor([cls_avg_factor]))
        cls_avg_factor = max(cls_avg_factor, 1)

        if len(cls_scores) > 0:
            # RT-DETR's VarifocalLoss focuses more on positive targets, which
            # makes the model pay more attention to rarely seen objects. For
            # barcode detection we may not need it, since only one class is
            # of interest.
            if isinstance(self.loss_cls, VarifocalLoss):
                bg_class_ind = self.num_classes
                pos_inds = ((labels >= 0)
                            & (labels < bg_class_ind)).nonzero().squeeze(1)
                cls_iou_targets = label_weights.new_zeros(cls_scores.shape)
                pos_bbox_targets = bbox_targets[pos_inds]
                pos_decode_bbox_targets = bbox_cxcywh_to_xyxy(
                    pos_bbox_targets)
                pos_bbox_pred = dn_bbox_preds.reshape(-1, 4)[pos_inds]
                pos_decode_bbox_pred = bbox_cxcywh_to_xyxy(pos_bbox_pred)
                pos_labels = labels[pos_inds]
                cls_iou_targets[pos_inds, pos_labels] = bbox_overlaps(
                    pos_decode_bbox_pred.detach(),
                    pos_decode_bbox_targets,
                    is_aligned=True)
                loss_cls = self.loss_cls(
                    cls_scores, cls_iou_targets, avg_factor=cls_avg_factor)
            else:
                loss_cls = self.loss_cls(
                    cls_scores,
                    labels,
                    label_weights,
                    avg_factor=cls_avg_factor)
        else:
            loss_cls = torch.zeros(
                1, dtype=cls_scores.dtype, device=cls_scores.device)

        # Compute the average number of gt boxes across all gpus, for
        # normalization purposes
        num_total_pos = loss_cls.new_tensor([num_total_pos])
        num_total_pos = torch.clamp(reduce_mean(num_total_pos), min=1).item()

        # construct factors used to rescale bboxes
        factors = []
        for img_meta, bbox_pred in zip(batch_img_metas, dn_bbox_preds):
            img_h, img_w = img_meta['img_shape']
            factor = bbox_pred.new_tensor([img_w, img_h, img_w,
                                           img_h]).unsqueeze(0).repeat(
                                               bbox_pred.size(0), 1)
            factors.append(factor)
        factors = torch.cat(factors)

        # DETR regresses normalized (cx, cy, w, h) boxes, so the learning
        # targets are normalized by the image size; rescale both predictions
        # and targets to absolute coordinates before computing the IoU loss.
        bbox_preds = dn_bbox_preds.reshape(-1, 4)
        bboxes = bbox_cxcywh_to_xyxy(bbox_preds) * factors
        bboxes_gt = bbox_cxcywh_to_xyxy(bbox_targets) * factors

        # regression IoU loss, GIoU loss by default
        loss_iou = self.loss_iou(
            bboxes, bboxes_gt, bbox_weights, avg_factor=num_total_pos)

        # regression L1 loss (in normalized cxcywh space)
        loss_bbox = self.loss_bbox(
            bbox_preds, bbox_targets, bbox_weights, avg_factor=num_total_pos)
        return loss_cls, loss_bbox, loss_iou
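

# ---------------------------------------------------------------------------
# Minimal sketch (illustrative only, not used by the head): how the IoU-aware
# classification targets for the VarifocalLoss branch in
# `loss_by_feat_single` / `_loss_dn_single` are built. All values below are
# dummy data chosen for demonstration.
if __name__ == '__main__':
    num_queries, num_classes = 4, 3
    # flattened classification logits, shape (num_queries, num_classes)
    dummy_cls_scores = torch.randn(num_queries, num_classes)
    # assigned labels: values < num_classes are positives, num_classes is
    # the background index
    dummy_labels = torch.tensor([1, num_classes, 0, num_classes])
    # normalized (cx, cy, w, h) predictions and assigned targets
    dummy_bbox_preds = torch.tensor([[0.50, 0.50, 0.20, 0.40],
                                     [0.30, 0.30, 0.10, 0.10],
                                     [0.70, 0.60, 0.20, 0.20],
                                     [0.10, 0.10, 0.05, 0.05]])
    dummy_bbox_targets = torch.tensor([[0.52, 0.48, 0.22, 0.38],
                                       [0.00, 0.00, 0.00, 0.00],
                                       [0.68, 0.62, 0.18, 0.22],
                                       [0.00, 0.00, 0.00, 0.00]])

    pos_inds = ((dummy_labels >= 0)
                & (dummy_labels < num_classes)).nonzero().squeeze(1)
    cls_iou_targets = torch.zeros_like(dummy_cls_scores)
    # a positive query's target is the IoU between its decoded box and the
    # assigned gt box, written at the index of its assigned class; negative
    # queries keep an all-zero target row
    cls_iou_targets[pos_inds, dummy_labels[pos_inds]] = bbox_overlaps(
        bbox_cxcywh_to_xyxy(dummy_bbox_preds[pos_inds]),
        bbox_cxcywh_to_xyxy(dummy_bbox_targets[pos_inds]),
        is_aligned=True)
    print(cls_iou_targets)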