# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

import numpy as np
from mmengine.structures import InstanceData
from torch import Tensor

from mmdet.registry import MODELS, TASK_UTILS
from mmdet.structures import TrackSampleList
from mmdet.utils import OptConfigType
from .deep_sort import DeepSORT


@MODELS.register_module()
class StrongSORT(DeepSORT):
    """StrongSORT: Make DeepSORT Great Again.

    Details can be found at `StrongSORT <https://arxiv.org/abs/2202.13514>`_.

    Args:
        detector (dict): Configuration of detector. Defaults to None.
        reid (dict): Configuration of reid. Defaults to None.
        cmc (dict): Configuration of camera motion compensation.
            Defaults to None.
        tracker (dict): Configuration of tracker. Defaults to None.
        postprocess_model (dict): Configuration of the post-processing
            model, e.g. tracklet interpolation. Defaults to None.
        data_preprocessor (dict or ConfigDict, optional): The pre-process
            config of :class:`TrackDataPreprocessor`. It usually includes
            ``pad_size_divisor``, ``pad_value``, ``mean`` and ``std``.
        init_cfg (dict or list[dict]): Configuration of initialization.
            Defaults to None.
    """

    def __init__(self,
                 detector: Optional[dict] = None,
                 reid: Optional[dict] = None,
                 cmc: Optional[dict] = None,
                 tracker: Optional[dict] = None,
                 postprocess_model: Optional[dict] = None,
                 data_preprocessor: OptConfigType = None,
                 init_cfg: OptConfigType = None):
        super().__init__(detector, reid, tracker, data_preprocessor, init_cfg)

        if cmc is not None:
            self.cmc = TASK_UTILS.build(cmc)
        if postprocess_model is not None:
            self.postprocess_model = TASK_UTILS.build(postprocess_model)
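
    # Illustrative config sketch (not upstream code): how ``cmc`` and
    # ``postprocess_model`` might be specified when building this model via
    # ``MODELS.build``. The task-module type names below (e.g.
    # 'CameraMotionCompensation', 'InterpolateTracklets') are assumptions
    # based on common OpenMMLab task modules and may differ across versions;
    # check the TASK_UTILS registry of your installation.
    #
    #   model = MODELS.build(
    #       dict(
    #           type='StrongSORT',
    #           detector=detector_cfg,   # any detector config dict
    #           reid=reid_cfg,           # re-ID network config dict
    #           tracker=tracker_cfg,     # tracker config dict
    #           cmc=dict(type='CameraMotionCompensation'),
    #           postprocess_model=dict(type='InterpolateTracklets')))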

    @property
    def with_cmc(self):
        """bool: whether the framework has a camera motion compensation
        model."""
        return hasattr(self, 'cmc') and self.cmc is not None

    def predict(self,
                inputs: Tensor,
                data_samples: TrackSampleList,
                rescale: bool = True,
                **kwargs) -> TrackSampleList:
        """Predict results from a video and data samples with
        post-processing.

        Args:
            inputs (Tensor): Input images of shape (N, T, C, H, W), where
                N denotes the batch size and T the number of key frames
                and reference frames.
            data_samples (list[:obj:`TrackDataSample`]): The batch
                data samples. It usually includes information such
                as ``gt_instances``.
            rescale (bool, optional): If False, the returned bboxes and
                masks will fit the scale of the input image; otherwise they
                will fit the scale of the original image shape.
                Defaults to True.

        Returns:
            TrackSampleList: Tracking results of the input videos.
            Each TrackDataSample usually contains ``pred_track_instances``.
        """
        assert inputs.dim() == 5, \
            'The input images must be a 5D Tensor (N, T, C, H, W).'
        assert inputs.size(0) == 1, \
            'StrongSORT inference only supports batch size 1 per GPU for now.'
        assert len(data_samples) == 1, \
            'StrongSORT inference only supports batch size 1 per GPU for now.'

        track_data_sample = data_samples[0]
        video_len = len(track_data_sample)
        video_track_instances = []
        for frame_id in range(video_len):
            img_data_sample = track_data_sample[frame_id]
            single_img = inputs[:, frame_id].contiguous()
            # det_results is a List[DetDataSample] with one entry per image.
            det_results = self.detector.predict(single_img, [img_data_sample])
            assert len(det_results) == 1, 'Batch inference is not supported.'

            pred_track_instances = self.tracker.track(
                model=self,
                img=single_img,
                data_sample=det_results[0],
                data_preprocessor=self.preprocess_cfg,
                rescale=rescale,
                **kwargs)
            # Accumulate one MOT-style row per track:
            # (frame, track_id, x1, y1, w, h, score).
            for i in range(len(pred_track_instances.instances_id)):
                video_track_instances.append(
                    np.array([
                        frame_id + 1,
                        pred_track_instances.instances_id[i].cpu(),
                        pred_track_instances.bboxes[i][0].cpu(),
                        pred_track_instances.bboxes[i][1].cpu(),
                        (pred_track_instances.bboxes[i][2] -
                         pred_track_instances.bboxes[i][0]).cpu(),
                        (pred_track_instances.bboxes[i][3] -
                         pred_track_instances.bboxes[i][1]).cpu(),
                        pred_track_instances.scores[i].cpu()
                    ]))

        # Post-process the whole video at once, e.g. with tracklet
        # interpolation, over the stacked (frame, track_id, x1, y1, w, h,
        # score) rows.
        video_track_instances = np.array(video_track_instances).reshape(-1, 7)
        video_track_instances = self.postprocess_model.forward(
            video_track_instances)

        # Scatter the post-processed rows back to their source frames.
        for frame_id in range(video_len):
            track_data_sample[frame_id].pred_track_instances = \
                InstanceData(bboxes=video_track_instances[
                    video_track_instances[:, 0] == frame_id + 1, :])

        return [track_data_sample]
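

# Illustrative usage sketch (not upstream code). It assumes ``model`` is a
# built ``StrongSORT`` instance and ``track_data_sample`` is a
# ``TrackDataSample`` whose length matches the number of frames; the input
# shape follows the contract documented in ``predict``: (N, T, C, H, W)
# with N == 1.
#
#   import torch
#
#   video = torch.rand(1, 8, 3, 256, 256)  # one clip of 8 dummy frames
#   with torch.no_grad():
#       results = model.predict(video, [track_data_sample])
#   for frame_sample in results[0]:
#       # Post-processed MOT-style rows for this frame.
#       print(frame_sample.pred_track_instances.bboxes)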