from typing import Dict, List, Tuple

import torch
from torch import Tensor

from mmdet.registry import MODELS
from mmdet.structures import SampleList
from mmdet.structures.bbox import bbox_cxcywh_to_xyxy, bbox_overlaps
from mmdet.utils import InstanceList, reduce_mean
from ..losses import VarifocalLoss
from .dino_head import DINOHead


@MODELS.register_module()
class RTDETRHead(DINOHead):
    r"""Head of RT-DETR: DETRs Beat YOLOs on Real-time Object Detection.

    The overall loss framework is implemented in :class:`DINOHead`;
    ``RTDETRHead`` differs in ``loss_by_feat_single`` and ``_loss_dn_single``,
    which support an IoU-aware classification target when the classification
    loss is :class:`VarifocalLoss`.
    """

    def loss(self, hidden_states: Tensor, references: Tensor,
             enc_outputs_class: Tensor, enc_outputs_coord: Tensor,
             batch_data_samples: SampleList,
             dn_meta: Dict[str, int]) -> dict:
        """Compute the losses of the head from decoder and encoder outputs.

        ``dn_meta`` carries the denoising-query information produced by the
        detector (in MMDetection's DINO it typically contains
        ``num_denoising_queries`` and ``num_denoising_groups``).
        """
        batch_gt_instances = []
        batch_img_metas = []
        for data_sample in batch_data_samples:
            batch_img_metas.append(data_sample.metainfo)
            batch_gt_instances.append(data_sample.gt_instances)

        loss_inputs = (hidden_states, references, enc_outputs_class,
                       enc_outputs_coord, batch_gt_instances,
                       batch_img_metas, dn_meta)
        losses = self.loss_by_feat(*loss_inputs)
        return losses

    def forward(self, hidden_states: Tensor,
                references: Tensor) -> Tuple[Tensor]:
        """Forward function.

        Args:
            hidden_states (Tensor): Hidden states output from each decoder
                layer, has shape (num_decoder_layers, bs, num_queries, dim).
            references (list[Tensor]): List of references from the decoder.
                The first reference is the `init_reference` (initial) and the
                other num_decoder_layers(6) references are `inter_references`
                (intermediate). The `init_reference` has shape (bs,
                num_queries, 4) when `as_two_stage` of the detector is `True`,
                otherwise (bs, num_queries, 2). Each `inter_reference` has
                shape (bs, num_queries, 4) when `with_box_refine` of the
                detector is `True`, otherwise (bs, num_queries, 2). The
                coordinates are arranged as (cx, cy) when the last dimension
                is 2, and (cx, cy, w, h) when it is 4.

        Returns:
            tuple[Tensor]: Results of the head containing the following
            tensors.

            - all_layers_outputs_classes (Tensor): Outputs from the
              classification head, has shape (num_decoder_layers, bs,
              num_queries, cls_out_channels).
            - all_layers_outputs_coords (Tensor): Sigmoid outputs from the
              regression head with normalized coordinate format (cx, cy, w,
              h), has shape (num_decoder_layers, bs, num_queries, 4) with the
              last dimension arranged as (cx, cy, w, h).
        """
        # NOTE: The DINOHead-style per-layer branch application below is kept
        # for reference but intentionally skipped: in this implementation the
        # decoder is assumed to have already produced the per-layer
        # predictions, so the inputs are returned unchanged.
        #
        # all_layers_outputs_classes = []
        # all_layers_outputs_coords = []
        # for layer_id in range(hidden_states.shape[0]):
        #     reference = inverse_sigmoid(references[layer_id])
        #     # NOTE The last reference will not be used.
        #     hidden_state = hidden_states[layer_id]
        #     outputs_class = self.cls_branches[layer_id](hidden_state)
        #     tmp_reg_preds = self.reg_branches[layer_id](hidden_state)
        #     if reference.shape[-1] == 4:
        #         # When `layer` is 0 and `as_two_stage` of the detector
        #         # is `True`, or when `layer` is greater than 0 and
        #         # `with_box_refine` of the detector is `True`.
        #         tmp_reg_preds += reference
        #     else:
        #         # When `layer` is 0 and `as_two_stage` of the detector
        #         # is `False`, or when `layer` is greater than 0 and
        #         # `with_box_refine` of the detector is `False`.
        #         assert reference.shape[-1] == 2
        #         tmp_reg_preds[..., :2] += reference
        #     outputs_coord = tmp_reg_preds.sigmoid()
        #     all_layers_outputs_classes.append(outputs_class)
        #     all_layers_outputs_coords.append(outputs_coord)
        # all_layers_outputs_classes = torch.stack(all_layers_outputs_classes)
        # all_layers_outputs_coords = torch.stack(all_layers_outputs_coords)
        return hidden_states, references
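
    # Shape sketch (illustrative only; the sizes and decoder behaviour below
    # are assumptions, not guaranteed by this file): with 6 decoder layers,
    # batch size 2 and 300 queries, ``hidden_states`` would be
    # (6, 2, 300, cls_out_channels) and ``references`` (6, 2, 300, 4) in
    # normalized (cx, cy, w, h) format. Both are returned unchanged, and the
    # inherited loss/predict paths consume them as the stacked per-layer
    # outputs.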

    def loss_by_feat_single(self, cls_scores: Tensor, bbox_preds: Tensor,
                            batch_gt_instances: InstanceList,
                            batch_img_metas: List[dict]) -> Tuple[Tensor]:
        """Loss of a single decoder layer for the matching (non-denoising)
        queries, returning ``(loss_cls, loss_bbox, loss_iou)``."""
        num_imgs = cls_scores.size(0)
        cls_scores_list = [cls_scores[i] for i in range(num_imgs)]
        bbox_preds_list = [bbox_preds[i] for i in range(num_imgs)]
        cls_reg_targets = self.get_targets(cls_scores_list, bbox_preds_list,
                                           batch_gt_instances,
                                           batch_img_metas)
        (labels_list, label_weights_list, bbox_targets_list,
         bbox_weights_list, num_total_pos, num_total_neg) = cls_reg_targets
        labels = torch.cat(labels_list, 0)
        label_weights = torch.cat(label_weights_list, 0)
        bbox_targets = torch.cat(bbox_targets_list, 0)
        bbox_weights = torch.cat(bbox_weights_list, 0)

        # classification loss
        cls_scores = cls_scores.reshape(-1, self.cls_out_channels)
        # construct weighted avg_factor to match with the official DETR repo
        cls_avg_factor = num_total_pos * 1.0 + \
            num_total_neg * self.bg_cls_weight
        if self.sync_cls_avg_factor:
            cls_avg_factor = reduce_mean(
                cls_scores.new_tensor([cls_avg_factor]))
        cls_avg_factor = max(cls_avg_factor, 1)

        if isinstance(self.loss_cls, VarifocalLoss):
            # IoU-aware classification target: positive queries are
            # supervised with the IoU between the (detached) predicted box
            # and its matched ground-truth box; negatives stay zero.
            bg_class_ind = self.num_classes
            pos_inds = ((labels >= 0)
                        & (labels < bg_class_ind)).nonzero().squeeze(1)
            cls_iou_targets = label_weights.new_zeros(cls_scores.shape)
            pos_bbox_targets = bbox_targets[pos_inds]
            pos_decode_bbox_targets = bbox_cxcywh_to_xyxy(pos_bbox_targets)
            pos_bbox_pred = bbox_preds.reshape(-1, 4)[pos_inds]
            pos_decode_bbox_pred = bbox_cxcywh_to_xyxy(pos_bbox_pred)
            pos_labels = labels[pos_inds]
            cls_iou_targets[pos_inds, pos_labels] = bbox_overlaps(
                pos_decode_bbox_pred.detach(),
                pos_decode_bbox_targets,
                is_aligned=True)
            loss_cls = self.loss_cls(
                cls_scores, cls_iou_targets, avg_factor=cls_avg_factor)
        else:
            loss_cls = self.loss_cls(
                cls_scores, labels, label_weights, avg_factor=cls_avg_factor)

        # Compute the average number of gt boxes across all gpus, for
        # normalization purposes
        num_total_pos = loss_cls.new_tensor([num_total_pos])
        num_total_pos = torch.clamp(reduce_mean(num_total_pos), min=1).item()

        # construct factors used to rescale bboxes to the image size
        factors = []
        for img_meta, bbox_pred in zip(batch_img_metas, bbox_preds):
            img_h, img_w = img_meta['img_shape']
            factor = bbox_pred.new_tensor([img_w, img_h, img_w,
                                           img_h]).unsqueeze(0).repeat(
                                               bbox_pred.size(0), 1)
            factors.append(factor)
        factors = torch.cat(factors, 0)

        # DETR regresses the relative position of boxes (cxcywh) in the
        # image, so the learning target is normalized by the image size.
        # Re-scale predictions and targets for the IoU loss.
        bbox_preds = bbox_preds.reshape(-1, 4)
        bboxes = bbox_cxcywh_to_xyxy(bbox_preds) * factors
        bboxes_gt = bbox_cxcywh_to_xyxy(bbox_targets) * factors

        # regression IoU loss, GIoU loss by default
        loss_iou = self.loss_iou(
            bboxes, bboxes_gt, bbox_weights, avg_factor=num_total_pos)
        # regression L1 loss
        loss_bbox = self.loss_bbox(
            bbox_preds, bbox_targets, bbox_weights, avg_factor=num_total_pos)
        return loss_cls, loss_bbox, loss_iou
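
    # Toy illustration (values made up) of the IoU-aware target used with
    # VarifocalLoss above: with 3 queries and 2 classes, if only query 1 is
    # positive, matched to class 0, and its predicted box overlaps the GT box
    # with IoU 0.7, then
    #   cls_iou_targets = [[0.0, 0.0],
    #                      [0.7, 0.0],
    #                      [0.0, 0.0]]
    # so positives are supervised with their prediction quality while
    # negatives are pushed towards zero.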

    # Denoising branch: same as `loss_by_feat_single` except that it uses the
    # denoising targets from `get_dn_targets`.
    def _loss_dn_single(self, dn_cls_scores: Tensor, dn_bbox_preds: Tensor,
                        batch_gt_instances: InstanceList,
                        batch_img_metas: List[Dict],
                        dn_meta: Dict[str, int]) -> Tuple[Tensor]:
        cls_reg_targets = self.get_dn_targets(batch_gt_instances,
                                              batch_img_metas, dn_meta)
        (labels_list, label_weights_list, bbox_targets_list,
         bbox_weights_list, num_total_pos, num_total_neg) = cls_reg_targets
        labels = torch.cat(labels_list, 0)
        label_weights = torch.cat(label_weights_list, 0)
        bbox_targets = torch.cat(bbox_targets_list, 0)
        bbox_weights = torch.cat(bbox_weights_list, 0)

        # flatten to (num_denoising_queries * bs, cls_out_channels)
        cls_scores = dn_cls_scores.reshape(-1, self.cls_out_channels)
        # construct weighted avg_factor to match with the official DETR repo
        cls_avg_factor = \
            num_total_pos * 1.0 + num_total_neg * self.bg_cls_weight
        if self.sync_cls_avg_factor:
            cls_avg_factor = reduce_mean(
                cls_scores.new_tensor([cls_avg_factor]))
        cls_avg_factor = max(cls_avg_factor, 1)

        if len(cls_scores) > 0:
            # Varifocal loss puts more weight on positive targets, which can
            # make the model pay more attention to rarely seen objects. For a
            # single-class setting (e.g. barcode detection) it may bring
            # little benefit since only one category is involved.
            if isinstance(self.loss_cls, VarifocalLoss):
                bg_class_ind = self.num_classes
                pos_inds = ((labels >= 0)
                            & (labels < bg_class_ind)).nonzero().squeeze(1)
                cls_iou_targets = label_weights.new_zeros(cls_scores.shape)
                pos_bbox_targets = bbox_targets[pos_inds]
                pos_decode_bbox_targets = bbox_cxcywh_to_xyxy(pos_bbox_targets)
                pos_bbox_pred = dn_bbox_preds.reshape(-1, 4)[pos_inds]
                pos_decode_bbox_pred = bbox_cxcywh_to_xyxy(pos_bbox_pred)
                pos_labels = labels[pos_inds]
                cls_iou_targets[pos_inds, pos_labels] = bbox_overlaps(
                    pos_decode_bbox_pred.detach(),
                    pos_decode_bbox_targets,
                    is_aligned=True)
                loss_cls = self.loss_cls(
                    cls_scores, cls_iou_targets, avg_factor=cls_avg_factor)
            else:
                loss_cls = self.loss_cls(
                    cls_scores,
                    labels,
                    label_weights,
                    avg_factor=cls_avg_factor)
        else:
            loss_cls = torch.zeros(
                1, dtype=cls_scores.dtype, device=cls_scores.device)

        # Compute the average number of gt boxes across all gpus, for
        # normalization purposes
        num_total_pos = loss_cls.new_tensor([num_total_pos])
        num_total_pos = torch.clamp(reduce_mean(num_total_pos), min=1).item()

        # construct factors used to rescale bboxes to the image size
        factors = []
        for img_meta, bbox_pred in zip(batch_img_metas, dn_bbox_preds):
            img_h, img_w = img_meta['img_shape']
            factor = bbox_pred.new_tensor([img_w, img_h, img_w,
                                           img_h]).unsqueeze(0).repeat(
                                               bbox_pred.size(0), 1)
            factors.append(factor)
        factors = torch.cat(factors)

        # DETR regresses the relative position of boxes (cxcywh) in the
        # image, so the learning target is normalized by the image size.
        # Re-scale predictions and targets for the IoU loss.
        bbox_preds = dn_bbox_preds.reshape(-1, 4)
        bboxes = bbox_cxcywh_to_xyxy(bbox_preds) * factors
        bboxes_gt = bbox_cxcywh_to_xyxy(bbox_targets) * factors

        # regression IoU loss, GIoU loss by default
        loss_iou = self.loss_iou(
            bboxes, bboxes_gt, bbox_weights, avg_factor=num_total_pos)
        # regression L1 loss
        loss_bbox = self.loss_bbox(
            bbox_preds, bbox_targets, bbox_weights, avg_factor=num_total_pos)
        return loss_cls, loss_bbox, loss_iou
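

# Illustrative config sketch (an assumption, not part of this module): how the
# head could be referenced from an MMDetection-style detector config. The
# concrete arguments depend on the detector and dataset.
#
#     bbox_head=dict(
#         type='RTDETRHead',
#         num_classes=80,
#         sync_cls_avg_factor=True,
#         loss_cls=dict(
#             type='VarifocalLoss', use_sigmoid=True, loss_weight=1.0),
#         loss_bbox=dict(type='L1Loss', loss_weight=5.0),
#         loss_iou=dict(type='GIoULoss', loss_weight=2.0))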