# rtdetr.py

from typing import Dict, Tuple

from torch import Tensor

from mmdet.registry import MODELS
from mmdet.structures import OptSampleList

from .base_detr import DetectionTransformer
from . import RtDetrTransformer


@MODELS.register_module()
class RtDetr(DetectionTransformer):
    """RT-DETR detector.

    The hybrid encoder serves as the neck, and a simple encoder with
    positional encoding lives inside the decoder, so the base class's
    ``encoder`` and ``positional_encoding`` slots are left unset.
    """

    def _init_layers(self) -> None:
        # We use the hybrid encoder as the neck and put a simple encoder
        # with positional encoding inside the decoder.
        self.encoder = None
        self.decoder = RtDetrTransformer(**self.decoder)
        self.positional_encoding = None
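
    # A minimal sketch, an assumption rather than the original config, of
    # the kind of ``decoder`` dict this class expects before _init_layers
    # unpacks it into RtDetrTransformer; the real keys are whatever
    # RtDetrTransformer.__init__ accepts:
    #
    #   decoder = dict(
    #       num_classes=80,           # hypothetical key
    #       hidden_dim=256,           # hypothetical key
    #       num_decoder_layers=6)     # hypothetical key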

    def pre_transformer(
            self,
            mlvl_feats: Tuple[Tensor],
            batch_data_samples: OptSampleList = None) -> None:
        """No-op: the pre-transformer step is handled inside the
        transformer itself."""
        return None

    def forward_encoder(self, feat: Tensor, feat_mask: Tensor,
                        feat_pos: Tensor, **kwargs) -> None:
        """No-op: the standalone encoder is not used in RT-DETR."""
        return None

    def pre_decoder(self, memory: Tensor, **kwargs) -> None:
        """No-op: the pre-decoder step is implemented in
        RtDetrTransformer."""
        return None

    def forward_decoder(self, memory: Tensor, **kwargs) -> None:
        """No-op: the decoder forward pass is implemented in
        RtDetrTransformer."""
        return None

    def forward_transformer(
            self,
            img_feats: Tuple[Tensor],
            batch_data_samples: OptSampleList = None) -> Dict:
        """Forward process of the transformer.

        In the base class this is split into four steps,
        'pre_transformer' -> 'encoder' -> 'pre_decoder' -> 'decoder';
        here the whole pipeline is delegated to ``RtDetrTransformer``,
        so the parameter flow can be illustrated as follows:

        .. code:: text

            img_feats (from the hybrid encoder) & batch_data_samples
                                  |
                                  V
                      +---------------------+
                      | forward_transformer |
                      +---------------------+
                                  |
                                  V
                          head_inputs_dict

        Args:
            img_feats (tuple[Tensor]): Tuple of feature maps from the
                neck. Each feature map has shape (bs, dim, H, W).
            batch_data_samples (list[:obj:`DetDataSample`], optional): The
                batch data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
                Defaults to None.

        Returns:
            dict: The dictionary of bbox_head function inputs, which
            always includes the `hidden_states` of the decoder output
            and may contain `references` including the initial and
            intermediate references.
        """
        # RtDetrTransformer.forward(feats, pad_mask=None, gt_meta=None)
        # returns (out_bboxes, out_logits, enc_topk_bboxes,
        # enc_topk_logits, dn_meta).
        (out_bboxes, out_logits, enc_topk_bboxes, enc_topk_logits,
         dn_meta) = self.decoder(
             feats=img_feats,
             pad_mask=None,
             gt_meta=batch_data_samples)
        if self.training:
            # The encoder top-k proposals and the denoising meta info
            # are only needed for the auxiliary training losses.
            head_inputs_dict = dict(
                enc_outputs_coord=enc_topk_bboxes,
                enc_outputs_class=enc_topk_logits,
                dn_meta=dn_meta)
        else:
            head_inputs_dict = dict()
        decoder_outputs_dict = dict(
            hidden_states=out_logits, references=out_bboxes)
        head_inputs_dict.update(decoder_outputs_dict)
        return head_inputs_dict
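

# A minimal usage sketch, not part of the original file. It assumes an
# mmdet-style config where the hybrid encoder is registered as the neck
# and ``decoder`` carries the RtDetrTransformer kwargs; every key below
# is an illustrative assumption, not the repo's actual config.
#
#   from mmdet.registry import MODELS
#
#   model_cfg = dict(
#       type='RtDetr',
#       backbone=dict(type='ResNet', depth=50),  # hypothetical backbone cfg
#       neck=dict(type='HybridEncoder'),         # hypothetical neck cfg
#       decoder=dict(...),                       # RtDetrTransformer kwargs
#       bbox_head=dict(type='RtDetrHead'))       # hypothetical head cfg
#   detector = MODELS.build(model_cfg)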