sort_tracker.py

# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple

import numpy as np
import torch
from mmengine.structures import InstanceData
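
# motmetrics is an optional dependency: it provides the linear assignment
# solver used by the association steps below, and its presence is checked
# when the tracker is constructed.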
try:
    import motmetrics
    from motmetrics.lap import linear_sum_assignment
except ImportError:
    motmetrics = None
from torch import Tensor

from mmdet.registry import MODELS, TASK_UTILS
from mmdet.structures import DetDataSample
from mmdet.structures.bbox import bbox_overlaps, bbox_xyxy_to_cxcyah
from mmdet.utils import OptConfigType

from ..utils import imrenormalize
from .base_tracker import BaseTracker


@MODELS.register_module()
class SORTTracker(BaseTracker):
    """Tracker for SORT/DeepSORT.

    Args:
        obj_score_thr (float, optional): Threshold to filter the objects.
            Defaults to 0.3.
        motion (dict, optional): Configuration of the motion model.
            Defaults to None.
        reid (dict, optional): Configuration for the ReID model.

            - num_samples (int, optional): Number of samples to calculate the
              feature embeddings of a track. Defaults to 10.
            - img_scale (tuple, optional): Input scale of the ReID model.
              Defaults to (256, 128).
            - img_norm_cfg (dict, optional): Configuration to normalize the
              input. Defaults to None.
            - match_score_thr (float, optional): Similarity threshold for the
              matching process. Defaults to 2.0.
        match_iou_thr (float, optional): Threshold of the IoU matching process.
            Defaults to 0.7.
        num_tentatives (int, optional): Number of continuous frames to confirm
            a track. Defaults to 3.
    """

    def __init__(self,
                 motion: Optional[dict] = None,
                 obj_score_thr: float = 0.3,
                 reid: dict = dict(
                     num_samples=10,
                     img_scale=(256, 128),
                     img_norm_cfg=None,
                     match_score_thr=2.0),
                 match_iou_thr: float = 0.7,
                 num_tentatives: int = 3,
                 **kwargs):
        if motmetrics is None:
            raise RuntimeError('motmetrics is not installed, please '
                               'install it by: pip install motmetrics')
        super().__init__(**kwargs)
        if motion is not None:
            self.motion = TASK_UTILS.build(motion)
            assert self.motion is not None, \
                'SORT/DeepSORT needs a KalmanFilter motion model'
        self.obj_score_thr = obj_score_thr
        self.reid = reid
        self.match_iou_thr = match_iou_thr
        self.num_tentatives = num_tentatives

    @property
    def confirmed_ids(self) -> List:
        """Confirmed ids in the tracker."""
        ids = [id for id, track in self.tracks.items() if not track.tentative]
        return ids
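
    # The Kalman filter state is built from boxes in (cx, cy, a, h) format:
    # center x/y, aspect ratio (w/h) and height, hence the conversions in
    # init_track() and update_track().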
    def init_track(self, id: int, obj: Tuple[Tensor]) -> None:
        """Initialize a track."""
        super().init_track(id, obj)
        self.tracks[id].tentative = True
        bbox = bbox_xyxy_to_cxcyah(self.tracks[id].bboxes[-1])  # size = (1, 4)
        assert bbox.ndim == 2 and bbox.shape[0] == 1
        bbox = bbox.squeeze(0).cpu().numpy()
        self.tracks[id].mean, self.tracks[id].covariance = self.kf.initiate(
            bbox)

    def update_track(self, id: int, obj: Tuple[Tensor]) -> None:
        """Update a track."""
        super().update_track(id, obj)
        if self.tracks[id].tentative:
            if len(self.tracks[id]['bboxes']) >= self.num_tentatives:
                self.tracks[id].tentative = False
        bbox = bbox_xyxy_to_cxcyah(self.tracks[id].bboxes[-1])  # size = (1, 4)
        assert bbox.ndim == 2 and bbox.shape[0] == 1
        bbox = bbox.squeeze(0).cpu().numpy()
        self.tracks[id].mean, self.tracks[id].covariance = self.kf.update(
            self.tracks[id].mean, self.tracks[id].covariance, bbox)

    def pop_invalid_tracks(self, frame_id: int) -> None:
        """Pop out invalid tracks."""
        invalid_ids = []
        for k, v in self.tracks.items():
            # case1: disappeared frames >= self.num_frames_retain
            case1 = frame_id - v['frame_ids'][-1] >= self.num_frames_retain
            # case2: tentative tracks but not matched in this frame
            case2 = v.tentative and v['frame_ids'][-1] != frame_id
            if case1 or case2:
                invalid_ids.append(k)
        for invalid_id in invalid_ids:
            self.tracks.pop(invalid_id)

    def track(self,
              model: torch.nn.Module,
              img: Tensor,
              data_sample: DetDataSample,
              data_preprocessor: OptConfigType = None,
              rescale: bool = False,
              **kwargs) -> InstanceData:
        """Tracking forward function.

        Args:
            model (nn.Module): MOT model.
            img (Tensor): of shape (T, C, H, W) encoding input image.
                Typically these should be mean centered and std scaled.
                The T denotes the number of key images and usually is 1 in
                the SORT method.
            data_sample (:obj:`DetDataSample`): The data sample.
                It includes information such as ``pred_instances``.
            data_preprocessor (dict or ConfigDict, optional): The pre-process
                config of :class:`TrackDataPreprocessor`. It usually includes
                ``pad_size_divisor``, ``pad_value``, ``mean`` and ``std``.
            rescale (bool, optional): If True, the bounding boxes should be
                rescaled to fit the original scale of the image. Defaults to
                False.

        Returns:
            :obj:`InstanceData`: Tracking results of the input images.
            Each InstanceData usually contains ``bboxes``, ``labels``,
            ``scores`` and ``instances_id``.
        """
        metainfo = data_sample.metainfo
        bboxes = data_sample.pred_instances.bboxes
        labels = data_sample.pred_instances.labels
        scores = data_sample.pred_instances.scores

        frame_id = metainfo.get('frame_id', -1)
        if frame_id == 0:
            self.reset()
        if not hasattr(self, 'kf'):
            self.kf = self.motion

        if self.with_reid:
            if self.reid.get('img_norm_cfg', False):
                img_norm_cfg = dict(
                    mean=data_preprocessor['mean'],
                    std=data_preprocessor['std'],
                    to_bgr=data_preprocessor['rgb_to_bgr'])
                reid_img = imrenormalize(img, img_norm_cfg,
                                         self.reid['img_norm_cfg'])
            else:
                reid_img = img.clone()

        valid_inds = scores > self.obj_score_thr
        bboxes = bboxes[valid_inds]
        labels = labels[valid_inds]
        scores = scores[valid_inds]
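
        # First frame or no live tracks: every remaining detection starts a
        # new track.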
        if self.empty or bboxes.size(0) == 0:
            num_new_tracks = bboxes.size(0)
            ids = torch.arange(
                self.num_tracks,
                self.num_tracks + num_new_tracks,
                dtype=torch.long).to(bboxes.device)
            self.num_tracks += num_new_tracks
            if self.with_reid:
                crops = self.crop_imgs(reid_img, metainfo, bboxes.clone(),
                                       rescale)
                if crops.size(0) > 0:
                    embeds = model.reid(crops, mode='tensor')
                else:
                    embeds = crops.new_zeros((0, model.reid.head.out_channels))
        else:
            ids = torch.full((bboxes.size(0), ),
                             -1,
                             dtype=torch.long).to(bboxes.device)

            # motion
            self.tracks, costs = self.motion.track(self.tracks,
                                                   bbox_xyxy_to_cxcyah(bboxes))

            active_ids = self.confirmed_ids
            if self.with_reid:
                crops = self.crop_imgs(reid_img, metainfo, bboxes.clone(),
                                       rescale)
                embeds = model.reid(crops, mode='tensor')

                # reid
                if len(active_ids) > 0:
                    track_embeds = self.get(
                        'embeds',
                        active_ids,
                        self.reid.get('num_samples', None),
                        behavior='mean')
                    reid_dists = torch.cdist(track_embeds, embeds)

                    # support multi-class association
                    track_labels = torch.tensor([
                        self.tracks[id]['labels'][-1] for id in active_ids
                    ]).to(bboxes.device)
                    cate_match = labels[None, :] == track_labels[:, None]
                    cate_cost = (1 - cate_match.int()) * 1e6
                    reid_dists = (reid_dists + cate_cost).cpu().numpy()

                    valid_inds = [list(self.ids).index(_) for _ in active_ids]
                    reid_dists[~np.isfinite(costs[valid_inds, :])] = np.nan

                    row, col = linear_sum_assignment(reid_dists)
                    for r, c in zip(row, col):
                        dist = reid_dists[r, c]
                        if not np.isfinite(dist):
                            continue
                        if dist <= self.reid['match_score_thr']:
                            ids[c] = active_ids[r]
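
            # Fall back to IoU matching: tracks seen in the previous frame
            # but still unmatched compete for the remaining detections.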
            active_ids = [
                id for id in self.ids if id not in ids
                and self.tracks[id].frame_ids[-1] == frame_id - 1
            ]
            if len(active_ids) > 0:
                active_dets = torch.nonzero(ids == -1).squeeze(1)
                track_bboxes = self.get('bboxes', active_ids)
                ious = bbox_overlaps(track_bboxes, bboxes[active_dets])

                # support multi-class association
                track_labels = torch.tensor([
                    self.tracks[id]['labels'][-1] for id in active_ids
                ]).to(bboxes.device)
                cate_match = labels[None, active_dets] == track_labels[:, None]
                cate_cost = (1 - cate_match.int()) * 1e6

                dists = (1 - ious + cate_cost).cpu().numpy()
                row, col = linear_sum_assignment(dists)
                for r, c in zip(row, col):
                    dist = dists[r, c]
                    if dist < 1 - self.match_iou_thr:
                        ids[active_dets[c]] = active_ids[r]

            new_track_inds = ids == -1
            ids[new_track_inds] = torch.arange(
                self.num_tracks,
                self.num_tracks + new_track_inds.sum(),
                dtype=torch.long).to(bboxes.device)
            self.num_tracks += new_track_inds.sum()

        self.update(
            ids=ids,
            bboxes=bboxes,
            scores=scores,
            labels=labels,
            embeds=embeds if self.with_reid else None,
            frame_ids=frame_id)

        # update pred_track_instances
        pred_track_instances = InstanceData()
        pred_track_instances.bboxes = bboxes
        pred_track_instances.labels = labels
        pred_track_instances.scores = scores
        pred_track_instances.instances_id = ids
        return pred_track_instances
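

# A minimal sketch of a per-frame driving loop, with hypothetical `model`,
# `img` and `data_sample` objects (in practice the MOT model calls track()
# inside its predict step, so this is illustrative only):
#
#   tracker = SORTTracker(motion=dict(type='KalmanFilter'))
#   for img, data_sample in video_frames:
#       pred_track_instances = tracker.track(
#           model, img, data_sample, rescale=True)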