import math
from typing import Dict, Tuple

import numpy as np
from torch import Tensor, nn

from mmdet.registry import MODELS
from mmdet.structures import OptSampleList
from mmdet.utils import ConfigType, OptConfigType

from . import RtDetrTransformer
from .base_detr import DetectionTransformer


@MODELS.register_module()
class RtDetr(DetectionTransformer):
    """RT-DETR style detection transformer.

    Unlike the generic ``DetectionTransformer`` pipeline, RT-DETR places a
    hybrid encoder in the neck and bundles query selection, positional
    encoding and the decoder inside ``RtDetrTransformer``.  As a result the
    base class's encoder/decoder hooks are deliberate no-ops here, and the
    whole transformer runs in :meth:`forward_transformer`.
    """

    def _init_layers(self) -> None:
        """Initialize transformer layers.

        The hybrid encoder lives in the neck and positional encoding is
        produced inside ``RtDetrTransformer``, so only the decoder module is
        actually built.
        """
        self.encoder = None
        # At this point ``self.decoder`` still holds the decoder config dict
        # handed to the base class; replace it with the built module.
        self.decoder = RtDetrTransformer(**self.decoder)
        self.positional_encoding = None

    def pre_transformer(
            self,
            mlvl_feats: Tuple[Tensor],
            batch_data_samples: OptSampleList = None) -> None:
        """No-op: feature preparation is handled inside the transformer."""
        return None

    def forward_encoder(self, feat: Tensor, feat_mask: Tensor,
                        feat_pos: Tensor, **kwargs) -> None:
        """No-op: RT-DETR uses the hybrid encoder in the neck instead."""
        return None

    def pre_decoder(self, memory: Tensor, **kwargs) -> None:
        """No-op: pre-decoding is handled inside ``RtDetrTransformer``."""
        return None

    def forward_decoder(self, memory: Tensor, **kwargs) -> None:
        """No-op: decoding is handled inside ``RtDetrTransformer``."""
        return None

    def forward_transformer(
            self,
            img_feats: Tuple[Tensor],
            batch_data_samples: OptSampleList = None) -> Dict:
        """Forward process of the transformer.

        For RT-DETR the usual four-step pipeline
        ('pre_transformer' -> 'encoder' -> 'pre_decoder' -> 'decoder')
        collapses into a single call to ``self.decoder``:

        .. code:: text

            img_feats from hybrid encoder & batch_data_samples
                           |
                           V
                 +---------------------+
                 | forward_transformer |
                 +---------------------+
                           |
                           V
                   head_inputs_dict

        Args:
            img_feats (tuple[Tensor]): Tuple of feature maps from neck. Each
                feature map has shape (bs, dim, H, W).
            batch_data_samples (list[:obj:`DetDataSample`], optional): The
                batch data samples. It usually includes information such as
                `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
                Defaults to None.

        Returns:
            dict: The dictionary of bbox_head function inputs, which always
            includes the `hidden_states` of the decoder output and may
            contain `references` including the initial and intermediate
            references.
        """
        # ``self.decoder`` returns (out_bboxes, out_logits, enc_topk_bboxes,
        # enc_topk_logits, dn_meta).
        (out_bboxes, out_logits, enc_topk_bboxes, enc_topk_logits,
         dn_meta) = self.decoder(
             feats=img_feats, pad_mask=None, gt_meta=batch_data_samples)

        if self.training:
            # Encoder top-k outputs and denoising meta information are only
            # consumed by the loss computation.
            head_inputs_dict = dict(
                enc_outputs_coord=enc_topk_bboxes,
                enc_outputs_class=enc_topk_logits,
                dn_meta=dn_meta)
        else:
            head_inputs_dict = dict()

        # NOTE(review): logits are exposed under ``hidden_states`` and bboxes
        # under ``references`` — presumably to match the bbox head's expected
        # input names; confirm against the head implementation.
        head_inputs_dict.update(
            dict(hidden_states=out_logits, references=out_bboxes))
        return head_inputs_dict