1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495 |
- import numpy as np
- import math
- from typing import Tuple,Dict
- from torch import Tensor, nn
- from mmdet.registry import MODELS
- from mmdet.utils import OptConfigType,ConfigType
- from mmdet.structures import OptSampleList
- from .base_detr import DetectionTransformer
- from . import RtDetrTransformer
@MODELS.register_module()
class RtDetr(DetectionTransformer):
    """RT-DETR detector.

    RT-DETR places its hybrid encoder in the neck and folds the usual
    DETR ``pre_transformer`` / encoder / ``pre_decoder`` stages into
    :class:`RtDetrTransformer` (built from the ``decoder`` config).  The
    corresponding hooks inherited from :class:`DetectionTransformer` are
    therefore intentional no-ops here.
    """

    def _init_layers(self) -> None:
        """Initialize layers.

        The hybrid encoder lives in the neck and the transformer decoder
        handles its own positional encoding, so ``encoder`` and
        ``positional_encoding`` are set to ``None`` and the ``decoder``
        config dict is consumed to build the RT-DETR transformer.
        """
        self.encoder = None
        self.decoder = RtDetrTransformer(**self.decoder)
        self.positional_encoding = None

    def pre_transformer(
            self,
            mlvl_feats: Tuple[Tensor],
            batch_data_samples: OptSampleList = None) -> None:
        """No-op: pre-transformer processing is implemented inside
        :class:`RtDetrTransformer`."""
        return None

    def forward_encoder(self, feat: Tensor, feat_mask: Tensor,
                        feat_pos: Tensor, **kwargs) -> None:
        """No-op: RT-DETR uses the hybrid encoder in the neck instead of a
        standalone transformer encoder."""
        return None

    def pre_decoder(self, memory: Tensor, **kwargs) -> None:
        """No-op: pre-decoder processing is implemented inside
        :class:`RtDetrTransformer`."""
        return None

    def forward_decoder(self, memory: Tensor, **kwargs) -> None:
        """No-op: decoding is implemented inside
        :class:`RtDetrTransformer`."""
        return None

    def forward_transformer(self,
                            img_feats: Tuple[Tensor],
                            batch_data_samples: OptSampleList = None) -> Dict:
        """Forward process of the transformer.

        Unlike the generic DETR pipeline
        ('pre_transformer' -> 'encoder' -> 'pre_decoder' -> 'decoder'),
        RT-DETR delegates everything to ``self.decoder``:

        .. code:: text

            img_feats from hybrid encoder & batch_data_samples
                              |
                              V
                    +---------------------+
                    | forward_transformer |
                    +---------------------+
                              |
                              V
                       head_inputs_dict

        Args:
            img_feats (tuple[Tensor]): Tuple of feature maps from neck. Each
                feature map has shape (bs, dim, H, W).
            batch_data_samples (list[:obj:`DetDataSample`], optional): The
                batch data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
                Defaults to None.

        Returns:
            dict: The dictionary of bbox_head function inputs, which always
            includes the `hidden_states` of the decoder output and may
            contain `references` including the initial and intermediate
            references.
        """
        (out_bboxes, out_logits, enc_topk_bboxes, enc_topk_logits,
         dn_meta) = self.decoder(
             feats=img_feats,
             pad_mask=None,
             gt_meta=batch_data_samples)

        if self.training:
            # Encoder top-k outputs and denoising meta are only needed for
            # loss computation during training.
            head_input_dict = dict(
                enc_outputs_coord=enc_topk_bboxes,
                enc_outputs_class=enc_topk_logits,
                dn_meta=dn_meta)
        else:
            head_input_dict = dict()

        # NOTE(review): the decoder's class logits are exposed as
        # ``hidden_states`` and its boxes as ``references`` — confirm the
        # bbox head consuming this dict expects exactly this mapping.
        head_input_dict.update(
            dict(hidden_states=out_logits, references=out_bboxes))
        return head_input_dict
|