base.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from abc import ABCMeta, abstractmethod
  3. from typing import Dict, List, Tuple, Union
  4. import torch
  5. from mmengine.model import BaseModel
  6. from torch import Tensor
  7. from mmengine.logging import print_log
  8. from mmdet.structures import DetDataSample, OptSampleList, SampleList
  9. from mmdet.utils import InstanceList, OptConfigType, OptMultiConfig
  10. from ..utils import samplelist_boxtype2tensor
  11. ForwardResults = Union[Dict[str, torch.Tensor], List[DetDataSample],
  12. Tuple[torch.Tensor], torch.Tensor]
  13. class BaseDetector(BaseModel, metaclass=ABCMeta):
  14. """Base class for detectors.
  15. Args:
  16. data_preprocessor (dict or ConfigDict, optional): The pre-process
  17. config of :class:`BaseDataPreprocessor`. it usually includes,
  18. ``pad_size_divisor``, ``pad_value``, ``mean`` and ``std``.
  19. init_cfg (dict or ConfigDict, optional): the config to control the
  20. initialization. Defaults to None.
  21. """
  22. def __init__(self,
  23. data_preprocessor: OptConfigType = None,
  24. init_cfg: OptMultiConfig = None):
  25. super().__init__(
  26. data_preprocessor=data_preprocessor, init_cfg=init_cfg)
  27. #we need a switch deploy for reparametre vgg
  28. def switch_to_deploy(self):
  29. """Switch the sub-modules to deploy mode."""
  30. for name, layer in self.named_modules():
  31. if layer == self:
  32. continue
  33. if callable(getattr(layer, 'switch_to_deploy', None)):
  34. print_log(f'module {name} has been switched to deploy mode',
  35. 'current')
  36. layer.switch_to_deploy(self.test_cfg)
  37. @property
  38. def with_neck(self) -> bool:
  39. """bool: whether the detector has a neck"""
  40. return hasattr(self, 'neck') and self.neck is not None
  41. # TODO: these properties need to be carefully handled
  42. # for both single stage & two stage detectors
  43. @property
  44. def with_shared_head(self) -> bool:
  45. """bool: whether the detector has a shared head in the RoI Head"""
  46. return hasattr(self, 'roi_head') and self.roi_head.with_shared_head
  47. @property
  48. def with_bbox(self) -> bool:
  49. """bool: whether the detector has a bbox head"""
  50. return ((hasattr(self, 'roi_head') and self.roi_head.with_bbox)
  51. or (hasattr(self, 'bbox_head') and self.bbox_head is not None))
  52. @property
  53. def with_mask(self) -> bool:
  54. """bool: whether the detector has a mask head"""
  55. return ((hasattr(self, 'roi_head') and self.roi_head.with_mask)
  56. or (hasattr(self, 'mask_head') and self.mask_head is not None))
  57. def forward(self,
  58. inputs: torch.Tensor,
  59. data_samples: OptSampleList = None,
  60. mode: str = 'tensor') -> ForwardResults:
  61. """The unified entry for a forward process in both training and test.
  62. The method should accept three modes: "tensor", "predict" and "loss":
  63. - "tensor": Forward the whole network and return tensor or tuple of
  64. tensor without any post-processing, same as a common nn.Module.
  65. - "predict": Forward and return the predictions, which are fully
  66. processed to a list of :obj:`DetDataSample`.
  67. - "loss": Forward and return a dict of losses according to the given
  68. inputs and data samples.
  69. Note that this method doesn't handle either back propagation or
  70. parameter update, which are supposed to be done in :meth:`train_step`.
  71. Args:
  72. inputs (torch.Tensor): The input tensor with shape
  73. (N, C, ...) in general.
  74. data_samples (list[:obj:`DetDataSample`], optional): A batch of
  75. data samples that contain annotations and predictions.
  76. Defaults to None.
  77. mode (str): Return what kind of value. Defaults to 'tensor'.
  78. Returns:
  79. The return type depends on ``mode``.
  80. - If ``mode="tensor"``, return a tensor or a tuple of tensor.
  81. - If ``mode="predict"``, return a list of :obj:`DetDataSample`.
  82. - If ``mode="loss"``, return a dict of tensor.
  83. """
  84. if mode == 'loss':
  85. return self.loss(inputs, data_samples)
  86. elif mode == 'predict':
  87. return self.predict(inputs, data_samples)
  88. elif mode == 'tensor':
  89. return self._forward(inputs, data_samples)
  90. else:
  91. raise RuntimeError(f'Invalid mode "{mode}". '
  92. 'Only supports loss, predict and tensor mode')
  93. @abstractmethod
  94. def loss(self, batch_inputs: Tensor,
  95. batch_data_samples: SampleList) -> Union[dict, tuple]:
  96. """Calculate losses from a batch of inputs and data samples."""
  97. pass
  98. @abstractmethod
  99. def predict(self, batch_inputs: Tensor,
  100. batch_data_samples: SampleList) -> SampleList:
  101. """Predict results from a batch of inputs and data samples with post-
  102. processing."""
  103. pass
  104. @abstractmethod
  105. def _forward(self,
  106. batch_inputs: Tensor,
  107. batch_data_samples: OptSampleList = None):
  108. """Network forward process.
  109. Usually includes backbone, neck and head forward without any post-
  110. processing.
  111. """
  112. pass
  113. @abstractmethod
  114. def extract_feat(self, batch_inputs: Tensor):
  115. """Extract features from images."""
  116. pass
  117. def add_pred_to_datasample(self, data_samples: SampleList,
  118. results_list: InstanceList) -> SampleList:
  119. """Add predictions to `DetDataSample`.
  120. Args:
  121. data_samples (list[:obj:`DetDataSample`], optional): A batch of
  122. data samples that contain annotations and predictions.
  123. results_list (list[:obj:`InstanceData`]): Detection results of
  124. each image.
  125. Returns:
  126. list[:obj:`DetDataSample`]: Detection results of the
  127. input images. Each DetDataSample usually contain
  128. 'pred_instances'. And the ``pred_instances`` usually
  129. contains following keys.
  130. - scores (Tensor): Classification scores, has a shape
  131. (num_instance, )
  132. - labels (Tensor): Labels of bboxes, has a shape
  133. (num_instances, ).
  134. - bboxes (Tensor): Has a shape (num_instances, 4),
  135. the last dimension 4 arrange as (x1, y1, x2, y2).
  136. """
  137. for data_sample, pred_instances in zip(data_samples, results_list):
  138. data_sample.pred_instances = pred_instances
  139. samplelist_boxtype2tensor(data_samples)
  140. return data_samples