base.py

# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
from typing import Dict, List, Tuple, Union

from mmengine.model import BaseModel
from torch import Tensor

from mmdet.registry import MODELS
from mmdet.structures import OptTrackSampleList, TrackSampleList
from mmdet.utils import OptConfigType, OptMultiConfig


@MODELS.register_module()
class BaseMOTModel(BaseModel, metaclass=ABCMeta):
    """Base class for multiple object tracking.

    Args:
        data_preprocessor (dict or ConfigDict, optional): The pre-process
            config of :class:`TrackDataPreprocessor`. It usually includes
            ``pad_size_divisor``, ``pad_value``, ``mean`` and ``std``.
        init_cfg (dict or list[dict]): Initialization config dict.
    """

    def __init__(self,
                 data_preprocessor: OptConfigType = None,
                 init_cfg: OptMultiConfig = None) -> None:
        super().__init__(
            data_preprocessor=data_preprocessor, init_cfg=init_cfg)
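
    # A typical ``data_preprocessor`` config passed to subclasses might look
    # like the sketch below (the values are illustrative, not defaults of
    # this base class):
    #
    #     data_preprocessor = dict(
    #         type='TrackDataPreprocessor',
    #         mean=[123.675, 116.28, 103.53],
    #         std=[58.395, 57.12, 57.375],
    #         pad_size_divisor=32)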

    def freeze_module(self,
                      module: Union[List[str], Tuple[str], str]) -> None:
        """Freeze the given module(s) during training."""
        if isinstance(module, str):
            modules = [module]
        elif isinstance(module, (list, tuple)):
            modules = module
        else:
            raise TypeError('module must be a str, a list or a tuple.')
        for module in modules:
            m = getattr(self, module)
            m.eval()
            for param in m.parameters():
                param.requires_grad = False
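
    # Usage sketch (hypothetical): a subclass that wraps a pretrained
    # detector and reid model could keep their weights fixed by calling,
    # e.g., ``self.freeze_module(['detector', 'reid'])`` in its ``__init__``
    # or at the start of training.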

    @property
    def with_detector(self) -> bool:
        """bool: whether the framework has a detector."""
        return hasattr(self, 'detector') and self.detector is not None

    @property
    def with_reid(self) -> bool:
        """bool: whether the framework has a reid model."""
        return hasattr(self, 'reid') and self.reid is not None

    @property
    def with_motion(self) -> bool:
        """bool: whether the framework has a motion model."""
        return hasattr(self, 'motion') and self.motion is not None

    @property
    def with_track_head(self) -> bool:
        """bool: whether the framework has a track head."""
        return hasattr(self, 'track_head') and self.track_head is not None

    @property
    def with_tracker(self) -> bool:
        """bool: whether the framework has a tracker."""
        return hasattr(self, 'tracker') and self.tracker is not None

    def forward(self,
                inputs: Dict[str, Tensor],
                data_samples: OptTrackSampleList = None,
                mode: str = 'predict',
                **kwargs):
        """The unified entry for a forward process in both training and test.

        The method should accept three modes: "tensor", "predict" and "loss":

        - "tensor": Forward the whole network and return a tensor or a tuple
          of tensors without any post-processing, same as a common nn.Module.
        - "predict": Forward and return the predictions, which are fully
          processed to a list of :obj:`TrackDataSample`.
        - "loss": Forward and return a dict of losses according to the given
          inputs and data samples.

        Note that this method handles neither back propagation nor optimizer
        updating; both are done in :meth:`train_step`.

        Args:
            inputs (Dict[str, Tensor]): Input images of shape (N, T, C, H, W).
                Typically these should be mean centered and std scaled. N
                denotes the batch size and T denotes the number of
                key/reference frames.

                - img (Tensor): The key images.
                - ref_img (Tensor): The reference images.
            data_samples (list[:obj:`TrackDataSample`], optional): The
                annotation data of every sample. Defaults to None.
            mode (str): Return what kind of value. Defaults to 'predict'.

        Returns:
            The return type depends on ``mode``.

            - If ``mode="tensor"``, return a tensor or a tuple of tensors.
            - If ``mode="predict"``, return a list of :obj:`TrackDataSample`.
            - If ``mode="loss"``, return a dict of tensors.
        """
        if mode == 'loss':
            return self.loss(inputs, data_samples, **kwargs)
        elif mode == 'predict':
            return self.predict(inputs, data_samples, **kwargs)
        elif mode == 'tensor':
            return self._forward(inputs, data_samples, **kwargs)
        else:
            raise RuntimeError(f'Invalid mode "{mode}". '
                               'Only supports loss, predict and tensor mode')
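
    # Usage sketch (assumed tensor names and shapes, for illustration only):
    #
    #     inputs = dict(img=key_imgs, ref_img=ref_imgs)  # each (N, T, C, H, W)
    #     losses = model(inputs, data_samples, mode='loss')      # training
    #     results = model(inputs, data_samples, mode='predict')  # inference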

    @abstractmethod
    def loss(self, inputs: Dict[str, Tensor], data_samples: TrackSampleList,
             **kwargs) -> Union[dict, tuple]:
        """Calculate losses from a batch of inputs and data samples."""
        pass

    @abstractmethod
    def predict(self, inputs: Dict[str, Tensor],
                data_samples: TrackSampleList, **kwargs) -> TrackSampleList:
        """Predict results from a batch of inputs and data samples with
        post-processing."""
        pass

    def _forward(self,
                 inputs: Dict[str, Tensor],
                 data_samples: OptTrackSampleList = None,
                 **kwargs):
        """Network forward process. Usually includes backbone, neck and head
        forward without any post-processing.

        Args:
            inputs (Dict[str, Tensor]): Input images of shape (N, T, C, H, W).
            data_samples (List[:obj:`TrackDataSample`], optional): The data
                samples. It usually includes information such as
                ``gt_instances``.

        Returns:
            tuple[list]: A tuple of features from ``head`` forward.
        """
        raise NotImplementedError(
            "_forward function (namely 'tensor' mode) is not supported now")
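

# A minimal sketch of a concrete subclass (hypothetical, for illustration
# only): ``loss`` and ``predict`` are the only methods a subclass must
# implement; overriding ``_forward`` ('tensor' mode) is optional.
#
#     @MODELS.register_module()
#     class MySimpleTracker(BaseMOTModel):
#
#         def loss(self, inputs, data_samples, **kwargs) -> dict:
#             # compute detection / tracking losses from the key and
#             # reference frames in ``inputs``
#             ...
#
#         def predict(self, inputs, data_samples, **kwargs):
#             # run the detector frame by frame, associate detections with
#             # ``self.tracker`` and write the tracking results back into
#             # the returned ``TrackDataSample``s
#             ...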