# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Tuple

import numpy as np
import torch
from mmengine.structures import InstanceData

try:
    import motmetrics
    from motmetrics.lap import linear_sum_assignment
except ImportError:
    motmetrics = None
from torch import Tensor

from mmdet.models.utils import imrenormalize
from mmdet.registry import MODELS
from mmdet.structures import TrackDataSample
from mmdet.structures.bbox import bbox_overlaps, bbox_xyxy_to_cxcyah
from mmdet.utils import OptConfigType
from .sort_tracker import SORTTracker


def cosine_distance(x: Tensor, y: Tensor) -> np.ndarray:
    """Compute the cosine distance.

    Args:
        x (Tensor): Embeddings with shape (N, C).
        y (Tensor): Embeddings with shape (M, C).

    Returns:
        ndarray: Cosine distance with shape (N, M).
    """
    x = x.cpu().numpy()
    y = y.cpu().numpy()
    x = x / np.linalg.norm(x, axis=1, keepdims=True)
    y = y / np.linalg.norm(y, axis=1, keepdims=True)
    dists = 1. - np.dot(x, y.T)
    return dists
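

# A quick sanity check for ``cosine_distance`` (illustrative sketch, not
# part of the upstream module): identical rows have distance 0, orthogonal
# rows have distance 1.
#
#   >>> x = torch.eye(2)
#   >>> cosine_distance(x, x)
#   array([[0., 1.],
#          [1., 0.]], dtype=float32)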


@MODELS.register_module()
class StrongSORTTracker(SORTTracker):
    """Tracker for StrongSORT.

    Args:
        motion (dict, optional): Configuration of the motion model.
            Defaults to None.
        obj_score_thr (float, optional): Threshold to filter the objects.
            Defaults to 0.6.
        reid (dict, optional): Configuration for the ReID model.

            - num_samples (int, optional): Number of samples to calculate the
              feature embeddings of a track. Defaults to None.
            - img_scale (tuple, optional): Input scale of the ReID model.
              Defaults to (256, 128).
            - img_norm_cfg (dict, optional): Configuration to normalize the
              input. Defaults to None.
            - match_score_thr (float, optional): Similarity threshold for the
              matching process. Defaults to 0.3.
            - motion_weight (float, optional): Weight of the motion cost.
              Defaults to 0.02.
        match_iou_thr (float, optional): Threshold of the IoU matching
            process. Defaults to 0.7.
        num_tentatives (int, optional): Number of continuous frames to
            confirm a track. Defaults to 2.
    """

    def __init__(self,
                 motion: Optional[dict] = None,
                 obj_score_thr: float = 0.6,
                 reid: dict = dict(
                     num_samples=None,
                     img_scale=(256, 128),
                     img_norm_cfg=None,
                     match_score_thr=0.3,
                     motion_weight=0.02),
                 match_iou_thr: float = 0.7,
                 num_tentatives: int = 2,
                 **kwargs):
        if motmetrics is None:
            raise RuntimeError('motmetrics is not installed, '
                               'please install it by: pip install motmetrics')
        super().__init__(motion, obj_score_thr, reid, match_iou_thr,
                         num_tentatives, **kwargs)
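
    # A hedged configuration sketch (illustrative; not part of this file):
    # in an MMDetection MOT config the tracker is typically declared inside
    # the model dict and built through the registry. The ``motion`` entry
    # below is an assumption based on the usual StrongSORT setup; the
    # ``reid`` values are simply the defaults above.
    #
    #   tracker=dict(
    #       type='StrongSORTTracker',
    #       motion=dict(type='KalmanFilter', center_only=False, use_nsa=True),
    #       obj_score_thr=0.6,
    #       reid=dict(
    #           num_samples=None,
    #           img_scale=(256, 128),
    #           img_norm_cfg=None,
    #           match_score_thr=0.3,
    #           motion_weight=0.02),
    #       match_iou_thr=0.7,
    #       num_tentatives=2)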

    def update_track(self, id: int, obj: Tuple[Tensor]) -> None:
        """Update a track."""
        for k, v in zip(self.memo_items, obj):
            v = v[None]
            if self.momentums is not None and k in self.momentums:
                m = self.momentums[k]
                self.tracks[id][k] = (1 - m) * self.tracks[id][k] + m * v
            else:
                self.tracks[id][k].append(v)

        if self.tracks[id].tentative:
            if len(self.tracks[id]['bboxes']) >= self.num_tentatives:
                self.tracks[id].tentative = False

        bbox = bbox_xyxy_to_cxcyah(self.tracks[id].bboxes[-1])  # size (1, 4)
        assert bbox.ndim == 2 and bbox.shape[0] == 1
        bbox = bbox.squeeze(0).cpu().numpy()
        score = float(self.tracks[id].scores[-1].cpu())
        self.tracks[id].mean, self.tracks[id].covariance = self.kf.update(
            self.tracks[id].mean, self.tracks[id].covariance, bbox, score)
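
    # Note: unlike the plain SORT update, ``self.kf.update`` above also
    # receives the detection ``score``. In StrongSORT this drives the NSA
    # Kalman filter, which (per the paper) scales the measurement noise by
    # detection confidence, roughly R_k ~ (1 - score) * R, so confident
    # detections pull the state harder. The exact formula lives in the
    # motion model, not in this tracker.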

    def track(self,
              model: torch.nn.Module,
              img: Tensor,
              data_sample: TrackDataSample,
              data_preprocessor: OptConfigType = None,
              rescale: bool = False,
              **kwargs) -> InstanceData:
        """Tracking forward function.

        Args:
            model (nn.Module): MOT model.
            img (Tensor): Input image of shape (T, C, H, W), typically
                mean-centered and std-scaled. T denotes the number of key
                images and is usually 1 in SORT-style methods.
            data_sample (:obj:`TrackDataSample`): The data sample. It
                includes information such as `pred_det_instances`.
            data_preprocessor (dict or ConfigDict, optional): The pre-process
                config of :class:`TrackDataPreprocessor`. It usually includes
                ``pad_size_divisor``, ``pad_value``, ``mean`` and ``std``.
            rescale (bool, optional): If True, the bounding boxes are
                rescaled to fit the original scale of the image. Defaults to
                False.

        Returns:
            :obj:`InstanceData`: Tracking results of the input images.
            Each InstanceData usually contains ``bboxes``, ``labels``,
            ``scores`` and ``instances_id``.
        """
        metainfo = data_sample.metainfo
        bboxes = data_sample.pred_instances.bboxes
        labels = data_sample.pred_instances.labels
        scores = data_sample.pred_instances.scores

        frame_id = metainfo.get('frame_id', -1)
        if frame_id == 0:
            self.reset()
        if not hasattr(self, 'kf'):
            self.kf = self.motion

        if self.with_reid:
            if self.reid.get('img_norm_cfg', False):
                img_norm_cfg = dict(
                    mean=data_preprocessor.get('mean', [0, 0, 0]),
                    std=data_preprocessor.get('std', [1, 1, 1]),
                    to_bgr=data_preprocessor.get('rgb_to_bgr', False))
                reid_img = imrenormalize(img, img_norm_cfg,
                                         self.reid['img_norm_cfg'])
            else:
                reid_img = img.clone()

        valid_inds = scores > self.obj_score_thr
        bboxes = bboxes[valid_inds]
        labels = labels[valid_inds]
        scores = scores[valid_inds]

        if self.empty or bboxes.size(0) == 0:
            num_new_tracks = bboxes.size(0)
            ids = torch.arange(
                self.num_tracks,
                self.num_tracks + num_new_tracks,
                dtype=torch.long).to(bboxes.device)
            self.num_tracks += num_new_tracks
            if self.with_reid:
                crops = self.crop_imgs(reid_img, metainfo, bboxes.clone(),
                                       rescale)
                if crops.size(0) > 0:
                    embeds = model.reid(crops, mode='tensor')
                else:
                    embeds = crops.new_zeros(
                        (0, model.reid.head.out_channels))
        else:
            ids = torch.full((bboxes.size(0), ), -1,
                             dtype=torch.long).to(bboxes.device)

            # motion
            if model.with_cmc:
                num_samples = 1
                self.tracks = model.cmc.track(self.last_img, img, self.tracks,
                                              num_samples, frame_id, metainfo)
            self.tracks, motion_dists = self.motion.track(
                self.tracks, bbox_xyxy_to_cxcyah(bboxes))
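
            # ``model.cmc`` compensates for camera motion by warping the
            # stored track states to the current frame before the Kalman
            # step; ``self.motion.track`` then predicts every track and
            # returns the motion (gating) distances to the detections.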

            active_ids = self.confirmed_ids
            if self.with_reid:
                crops = self.crop_imgs(reid_img, metainfo, bboxes.clone(),
                                       rescale)
                embeds = model.reid(crops, mode='tensor')

                # reid
                if len(active_ids) > 0:
                    track_embeds = self.get(
                        'embeds',
                        active_ids,
                        self.reid.get('num_samples', None),
                        behavior='mean')
                    reid_dists = cosine_distance(track_embeds, embeds)
                    valid_inds = [list(self.ids).index(_) for _ in active_ids]
                    reid_dists[~np.isfinite(
                        motion_dists[valid_inds, :])] = np.nan
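
                    # Fuse appearance and motion into a single cost matrix:
                    # cost = (1 - w) * reid_dist + w * motion_dist, with w
                    # given by ``reid['motion_weight']`` (0.02 by default).
                    # Pairs gated out by the motion model (infinite motion
                    # cost) were set to NaN above and are skipped below.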
                    weight_motion = self.reid.get('motion_weight')
                    match_dists = (1 - weight_motion) * reid_dists + \
                        weight_motion * motion_dists[valid_inds]

                    # support multi-class association
                    track_labels = torch.tensor([
                        self.tracks[id]['labels'][-1] for id in active_ids
                    ]).to(bboxes.device)
                    cate_match = labels[None, :] == track_labels[:, None]
                    cate_cost = ((1 - cate_match.int()) * 1e6).cpu().numpy()
                    match_dists = match_dists + cate_cost

                    row, col = linear_sum_assignment(match_dists)
                    for r, c in zip(row, col):
                        dist = match_dists[r, c]
                        if not np.isfinite(dist):
                            continue
                        if dist <= self.reid['match_score_thr']:
                            ids[c] = active_ids[r]
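
            # Second association stage: match the remaining detections to
            # tracks seen in the previous frame but still unmatched, by IoU.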
            active_ids = [
                id for id in self.ids if id not in ids
                and self.tracks[id].frame_ids[-1] == frame_id - 1
            ]
            if len(active_ids) > 0:
                active_dets = torch.nonzero(ids == -1).squeeze(1)
                track_bboxes = self.get('bboxes', active_ids)
                ious = bbox_overlaps(track_bboxes, bboxes[active_dets])

                # support multi-class association
                track_labels = torch.tensor([
                    self.tracks[id]['labels'][-1] for id in active_ids
                ]).to(bboxes.device)
                cate_match = labels[None, active_dets] == track_labels[:,
                                                                       None]
                cate_cost = (1 - cate_match.int()) * 1e6

                dists = (1 - ious + cate_cost).cpu().numpy()

                row, col = linear_sum_assignment(dists)
                for r, c in zip(row, col):
                    dist = dists[r, c]
                    if dist < 1 - self.match_iou_thr:
                        ids[active_dets[c]] = active_ids[r]

            new_track_inds = ids == -1
            ids[new_track_inds] = torch.arange(
                self.num_tracks,
                self.num_tracks + new_track_inds.sum(),
                dtype=torch.long).to(bboxes.device)
            self.num_tracks += new_track_inds.sum()

        self.update(
            ids=ids,
            bboxes=bboxes,
            scores=scores,
            labels=labels,
            embeds=embeds if self.with_reid else None,
            frame_ids=frame_id)
        self.last_img = img

        # update pred_track_instances
        pred_track_instances = InstanceData()
        pred_track_instances.bboxes = bboxes
        pred_track_instances.labels = labels
        pred_track_instances.scores = scores
        pred_track_instances.instances_id = ids

        return pred_track_instances
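

# A hedged per-frame usage sketch (illustrative; not part of this file).
# ``mot_model`` is assumed to be a built MOT model exposing ``reid`` and
# ``cmc`` submodules, and ``sample`` a ``TrackDataSample`` whose
# ``pred_instances`` hold the current frame's detections:
#
#   tracker = StrongSORTTracker()
#   for frame_id, (frame, sample) in enumerate(video):
#       sample.set_metainfo(dict(frame_id=frame_id))
#       pred = tracker.track(mot_model, frame, sample, rescale=True)
#       # pred.bboxes, pred.labels, pred.scores, pred.instances_id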