bytetrack.py

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

from torch import Tensor

from mmdet.registry import MODELS
from mmdet.structures import SampleList, TrackSampleList
from mmdet.utils import OptConfigType, OptMultiConfig
from .base import BaseMOTModel


@MODELS.register_module()
class ByteTrack(BaseMOTModel):
    """ByteTrack: Multi-Object Tracking by Associating Every Detection Box.

    This multi-object tracker is the implementation of `ByteTrack
    <https://arxiv.org/abs/2110.06864>`_. A config sketch is given at the
    bottom of this file.

    Args:
        detector (dict): Configuration of the detector. Defaults to None.
        tracker (dict): Configuration of the tracker. Defaults to None.
        data_preprocessor (dict or ConfigDict, optional): The pre-process
            config of :class:`TrackDataPreprocessor`. It usually includes
            ``pad_size_divisor``, ``pad_value``, ``mean`` and ``std``.
        init_cfg (dict or list[dict]): Configuration of initialization.
            Defaults to None.
    """

    def __init__(self,
                 detector: Optional[dict] = None,
                 tracker: Optional[dict] = None,
                 data_preprocessor: OptConfigType = None,
                 init_cfg: OptMultiConfig = None):
        super().__init__(data_preprocessor, init_cfg)

        if detector is not None:
            self.detector = MODELS.build(detector)

        if tracker is not None:
            self.tracker = MODELS.build(tracker)

    def loss(self, inputs: Tensor, data_samples: SampleList, **kwargs) -> dict:
        """Calculate losses from a batch of inputs and data samples.

        Args:
            inputs (Tensor): Input images of shape (N, C, H, W).
                Typically these should be mean centered and std scaled.
                The N denotes the batch size.
            data_samples (list[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such
                as ``gt_instances``.

        Returns:
            dict: A dictionary of loss components.
        """
        return self.detector.loss(inputs, data_samples, **kwargs)

    def predict(self, inputs: Tensor, data_samples: TrackSampleList,
                **kwargs) -> TrackSampleList:
  47. """Predict results from a video and data samples with post-processing.
  48. Args:
  49. inputs (Tensor): of shape (N, T, C, H, W) encoding
  50. input images. The N denotes batch size.
  51. The T denotes the number of frames in a video.
  52. data_samples (list[:obj:`TrackDataSample`]): The batch
  53. data samples. It usually includes information such
  54. as `video_data_samples`.
  55. Returns:
  56. TrackSampleList: Tracking results of the inputs.
  57. """
        assert inputs.dim() == 5, 'The img must be 5D Tensor (N, T, C, H, W).'
        assert inputs.size(0) == 1, \
            'ByteTrack inference only supports batch size 1 per GPU for now.'
        assert len(data_samples) == 1, \
            'ByteTrack inference only supports batch size 1 per GPU for now.'
        track_data_sample = data_samples[0]
        video_len = len(track_data_sample)

        for frame_id in range(video_len):
            img_data_sample = track_data_sample[frame_id]
            single_img = inputs[:, frame_id].contiguous()
            # Run the detector on this frame; det_results is a
            # List[DetDataSample] with a single element.
            det_results = self.detector.predict(single_img, [img_data_sample])
            assert len(det_results) == 1, 'Batch inference is not supported.'
            # Associate the detections of this frame with the existing tracks.
            pred_track_instances = self.tracker.track(
                data_sample=det_results[0], **kwargs)
            img_data_sample.pred_track_instances = pred_track_instances

        return [track_data_sample]
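
# ----------------------------------------------------------------------------
# Config sketch (illustrative, not part of the upstream file). A ByteTrack
# model is normally assembled through the registry from a config dict whose
# layout mirrors the ``Args`` section of the class docstring above. The
# detector entry is abbreviated here; the reference configs plug in a full
# YOLOX detector. The tracker keys follow the published MMDetection ByteTrack
# configs and may differ between versions, so treat the exact values as
# assumptions.
#
#   model = dict(
#       type='ByteTrack',
#       data_preprocessor=dict(
#           type='TrackDataPreprocessor',
#           pad_size_divisor=32),
#       detector=dict(type='YOLOX', ...),  # full detector config goes here
#       tracker=dict(
#           type='ByteTracker',
#           motion=dict(type='KalmanFilter'),
#           obj_score_thrs=dict(high=0.6, low=0.1),
#           init_track_thr=0.7,
#           weight_iou_with_det_scores=True,
#           match_iou_thrs=dict(high=0.1, low=0.5, tentative=0.3),
#           num_tentatives=3))
#
#   from mmdet.registry import MODELS
#   byte_track = MODELS.build(model)
# ----------------------------------------------------------------------------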
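
# ----------------------------------------------------------------------------
# Inference sketch (illustrative, not part of the upstream file). ``predict``
# consumes one whole clip per call: a (N=1, T, C, H, W) tensor plus a single
# ``TrackDataSample`` holding one per-frame ``DetDataSample``. The variable
# names below are hypothetical, and the attribute names on the returned track
# instances follow the ByteTracker output in recent MMDetection versions.
#
#   import torch
#
#   clip = torch.rand(1, 10, 3, 800, 1440)   # one video with 10 frames
#   track_sample = ...                        # a TrackDataSample with 10
#                                             # per-frame DetDataSamples
#   results = byte_track.predict(clip, [track_sample])
#   frame0 = results[0][0]                    # DetDataSample of frame 0
#   ids = frame0.pred_track_instances.instances_id
#   boxes = frame0.pred_track_instances.bboxes
# ----------------------------------------------------------------------------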