deep_sort.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import Optional
  3. from torch import Tensor
  4. from mmdet.registry import MODELS
  5. from mmdet.structures import TrackSampleList
  6. from mmdet.utils import OptConfigType
  7. from .base import BaseMOTModel
  8. @MODELS.register_module()
  9. class DeepSORT(BaseMOTModel):
  10. """Simple online and realtime tracking with a deep association metric.
  11. Details can be found at `DeepSORT<https://arxiv.org/abs/1703.07402>`_.
  12. Args:
  13. detector (dict): Configuration of detector. Defaults to None.
  14. reid (dict): Configuration of reid. Defaults to None
  15. tracker (dict): Configuration of tracker. Defaults to None.
  16. data_preprocessor (dict or ConfigDict, optional): The pre-process
  17. config of :class:`TrackDataPreprocessor`. it usually includes,
  18. ``pad_size_divisor``, ``pad_value``, ``mean`` and ``std``.
  19. init_cfg (dict or list[dict]): Configuration of initialization.
  20. Defaults to None.
  21. """
  22. def __init__(self,
  23. detector: Optional[dict] = None,
  24. reid: Optional[dict] = None,
  25. tracker: Optional[dict] = None,
  26. data_preprocessor: OptConfigType = None,
  27. init_cfg: OptConfigType = None):
  28. super().__init__(data_preprocessor, init_cfg)
  29. if detector is not None:
  30. self.detector = MODELS.build(detector)
  31. if reid is not None:
  32. self.reid = MODELS.build(reid)
  33. if tracker is not None:
  34. self.tracker = MODELS.build(tracker)
  35. self.preprocess_cfg = data_preprocessor
  36. def loss(self, inputs: Tensor, data_samples: TrackSampleList,
  37. **kwargs) -> dict:
  38. """Calculate losses from a batch of inputs and data samples."""
  39. raise NotImplementedError(
  40. 'Please train `detector` and `reid` models firstly, then \
  41. inference with SORT/DeepSORT.')
  42. def predict(self,
  43. inputs: Tensor,
  44. data_samples: TrackSampleList,
  45. rescale: bool = True,
  46. **kwargs) -> TrackSampleList:
  47. """Predict results from a video and data samples with post- processing.
  48. Args:
  49. inputs (Tensor): of shape (N, T, C, H, W) encoding
  50. input images. The N denotes batch size.
  51. The T denotes the number of key frames
  52. and reference frames.
  53. data_samples (list[:obj:`TrackDataSample`]): The batch
  54. data samples. It usually includes information such
  55. as `gt_instance`.
  56. rescale (bool, Optional): If False, then returned bboxes and masks
  57. will fit the scale of img, otherwise, returned bboxes and masks
  58. will fit the scale of original image shape. Defaults to True.
  59. Returns:
  60. TrackSampleList: List[TrackDataSample]
  61. Tracking results of the input videos.
  62. Each DetDataSample usually contains ``pred_track_instances``.
  63. """
  64. assert inputs.dim() == 5, 'The img must be 5D Tensor (N, T, C, H, W).'
  65. assert inputs.size(0) == 1, \
  66. 'SORT/DeepSORT inference only support ' \
  67. '1 batch size per gpu for now.'
  68. assert len(data_samples) == 1, \
  69. 'SORT/DeepSORT inference only support ' \
  70. '1 batch size per gpu for now.'
  71. track_data_sample = data_samples[0]
  72. video_len = len(track_data_sample)
  73. if track_data_sample[0].frame_id == 0:
  74. self.tracker.reset()
  75. for frame_id in range(video_len):
  76. img_data_sample = track_data_sample[frame_id]
  77. single_img = inputs[:, frame_id].contiguous()
  78. # det_results List[DetDataSample]
  79. det_results = self.detector.predict(single_img, [img_data_sample])
  80. assert len(det_results) == 1, 'Batch inference is not supported.'
  81. pred_track_instances = self.tracker.track(
  82. model=self,
  83. img=single_img,
  84. feats=None,
  85. data_sample=det_results[0],
  86. data_preprocessor=self.preprocess_cfg,
  87. rescale=rescale,
  88. **kwargs)
  89. img_data_sample.pred_track_instances = pred_track_instances
  90. return [track_data_sample]