byte_tracker.py
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple

try:
    import lap
except ImportError:
    lap = None
import numpy as np
import torch
from mmengine.structures import InstanceData

from mmdet.registry import MODELS, TASK_UTILS
from mmdet.structures import DetDataSample
from mmdet.structures.bbox import (bbox_cxcyah_to_xyxy, bbox_overlaps,
                                   bbox_xyxy_to_cxcyah)
from .base_tracker import BaseTracker


@MODELS.register_module()
class ByteTracker(BaseTracker):
    """Tracker for ByteTrack.

    Args:
        motion (dict): Configuration of motion. Defaults to None.
        obj_score_thrs (dict): Detection score thresholds for matching
            objects.
            - high (float): Threshold of the first matching. Defaults to 0.6.
            - low (float): Threshold of the second matching. Defaults to 0.1.
        init_track_thr (float): Detection score threshold for initializing a
            new tracklet. Defaults to 0.7.
        weight_iou_with_det_scores (bool): Whether to weight the IoU used for
            matching with detection scores. Defaults to True.
        match_iou_thrs (dict): IoU distance thresholds for matching between
            two frames.
            - high (float): Threshold of the first matching. Defaults to 0.1.
            - low (float): Threshold of the second matching. Defaults to 0.5.
            - tentative (float): Threshold of the matching for tentative
              tracklets. Defaults to 0.3.
        num_tentatives (int, optional): Number of continuous frames to confirm
            a track. Defaults to 3.
    """

    def __init__(self,
                 motion: Optional[dict] = None,
                 obj_score_thrs: dict = dict(high=0.6, low=0.1),
                 init_track_thr: float = 0.7,
                 weight_iou_with_det_scores: bool = True,
                 match_iou_thrs: dict = dict(high=0.1, low=0.5, tentative=0.3),
                 num_tentatives: int = 3,
                 **kwargs):
        super().__init__(**kwargs)
        if lap is None:
            raise RuntimeError('lap is not installed, '
                               'please install it by: pip install lap')
        if motion is not None:
            self.motion = TASK_UTILS.build(motion)
        self.obj_score_thrs = obj_score_thrs
        self.init_track_thr = init_track_thr
        self.weight_iou_with_det_scores = weight_iou_with_det_scores
        self.match_iou_thrs = match_iou_thrs
        self.num_tentatives = num_tentatives

    @property
    def confirmed_ids(self) -> List:
        """Confirmed ids in the tracker."""
        ids = [id for id, track in self.tracks.items() if not track.tentative]
        return ids

    @property
    def unconfirmed_ids(self) -> List:
        """Unconfirmed ids in the tracker."""
        ids = [id for id, track in self.tracks.items() if track.tentative]
        return ids

    def init_track(self, id: int, obj: Tuple[torch.Tensor]) -> None:
        """Initialize a track."""
        super().init_track(id, obj)
        if self.tracks[id].frame_ids[-1] == 0:
            self.tracks[id].tentative = False
        else:
            self.tracks[id].tentative = True
        bbox = bbox_xyxy_to_cxcyah(self.tracks[id].bboxes[-1])  # size = (1, 4)
        assert bbox.ndim == 2 and bbox.shape[0] == 1
        bbox = bbox.squeeze(0).cpu().numpy()
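        # NOTE (added): the motion model is assumed to be a DeepSORT-style
        # Kalman filter whose state is (cx, cy, a, h) plus their velocities,
        # so `initiate` returns an 8-dim mean and an 8x8 covariance.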
        self.tracks[id].mean, self.tracks[id].covariance = self.kf.initiate(
            bbox)

    def update_track(self, id: int, obj: Tuple[torch.Tensor]) -> None:
        """Update a track."""
        super().update_track(id, obj)
        if self.tracks[id].tentative:
            if len(self.tracks[id]['bboxes']) >= self.num_tentatives:
                self.tracks[id].tentative = False
        bbox = bbox_xyxy_to_cxcyah(self.tracks[id].bboxes[-1])  # size = (1, 4)
        assert bbox.ndim == 2 and bbox.shape[0] == 1
        bbox = bbox.squeeze(0).cpu().numpy()
        track_label = self.tracks[id]['labels'][-1]
        label_idx = self.memo_items.index('labels')
        obj_label = obj[label_idx]
        assert obj_label == track_label
        self.tracks[id].mean, self.tracks[id].covariance = self.kf.update(
            self.tracks[id].mean, self.tracks[id].covariance, bbox)

    def pop_invalid_tracks(self, frame_id: int) -> None:
        """Pop out invalid tracks."""
        invalid_ids = []
        for k, v in self.tracks.items():
            # case1: disappeared frames >= self.num_frames_retain
            case1 = frame_id - v['frame_ids'][-1] >= self.num_frames_retain
            # case2: tentative tracks that were not matched in this frame
            case2 = v.tentative and v['frame_ids'][-1] != frame_id
            if case1 or case2:
                invalid_ids.append(k)
        for invalid_id in invalid_ids:
            self.tracks.pop(invalid_id)

    def assign_ids(
            self,
            ids: List[int],
            det_bboxes: torch.Tensor,
            det_labels: torch.Tensor,
            det_scores: torch.Tensor,
            weight_iou_with_det_scores: Optional[bool] = False,
            match_iou_thr: Optional[float] = 0.5
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Assign ids.

        Args:
            ids (list[int]): Tracking ids.
            det_bboxes (Tensor): Detection bboxes of shape (N, 4).
            det_labels (Tensor): Detection labels of shape (N,).
            det_scores (Tensor): Detection scores of shape (N,).
            weight_iou_with_det_scores (bool, optional): Whether to weight
                the IoU used for matching with detection scores.
                Defaults to False.
            match_iou_thr (float, optional): Matching threshold.
                Defaults to 0.5.

        Returns:
            tuple(np.ndarray, np.ndarray): The assigned ids: per-track
                matched detection indices and per-detection matched track
                indices, with -1 for unmatched entries.
        """
        # get track_bboxes
        track_bboxes = np.zeros((0, 4))
        for id in ids:
            track_bboxes = np.concatenate(
                (track_bboxes, self.tracks[id].mean[:4][None]), axis=0)
        track_bboxes = torch.from_numpy(track_bboxes).to(det_bboxes)
        track_bboxes = bbox_cxcyah_to_xyxy(track_bboxes)

        # compute distance
        ious = bbox_overlaps(track_bboxes, det_bboxes)
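        # NOTE (added): when enabled, the IoU is scaled by detection
        # confidence, so a low-score detection needs a higher raw IoU to
        # match (the score-fusion trick from the ByteTrack paper).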
        if weight_iou_with_det_scores:
            ious *= det_scores

        # support multi-class association
        track_labels = torch.tensor([
            self.tracks[id]['labels'][-1] for id in ids
        ]).to(det_bboxes.device)
        cate_match = det_labels[None, :] == track_labels[:, None]
        # avoid matching detections and tracks of different categories
        cate_cost = (1 - cate_match.int()) * 1e6
        dists = (1 - ious + cate_cost).cpu().numpy()

        # bipartite match
        if dists.size > 0:
            cost, row, col = lap.lapjv(
                dists, extend_cost=True, cost_limit=1 - match_iou_thr)
        else:
            row = np.zeros(len(ids)).astype(np.int32) - 1
            col = np.zeros(len(det_bboxes)).astype(np.int32) - 1
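        # NOTE (added): `row[i]` is the detection index assigned to track i
        # and `col[j]` the track index assigned to detection j; unmatched
        # entries are -1 (`extend_cost=True` pads the cost matrix to square).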
        return row, col

    def track(self, data_sample: DetDataSample, **kwargs) -> InstanceData:
        """Tracking forward function.

        Args:
            data_sample (:obj:`DetDataSample`): The data sample.
                It includes information such as `pred_instances`.

        Returns:
            :obj:`InstanceData`: Tracking results of the input images.
                Each InstanceData usually contains ``bboxes``, ``labels``,
                ``scores`` and ``instances_id``.
        """
        metainfo = data_sample.metainfo
        bboxes = data_sample.pred_instances.bboxes
        labels = data_sample.pred_instances.labels
        scores = data_sample.pred_instances.scores

        frame_id = metainfo.get('frame_id', -1)
        if frame_id == 0:
            self.reset()
        if not hasattr(self, 'kf'):
            self.kf = self.motion

        if self.empty or bboxes.size(0) == 0:
            valid_inds = scores > self.init_track_thr
            scores = scores[valid_inds]
            bboxes = bboxes[valid_inds]
            labels = labels[valid_inds]
            num_new_tracks = bboxes.size(0)
            ids = torch.arange(self.num_tracks,
                               self.num_tracks + num_new_tracks).to(labels)
            self.num_tracks += num_new_tracks
        else:
            # 0. init
            ids = torch.full((bboxes.size(0), ),
                             -1,
                             dtype=labels.dtype,
                             device=labels.device)

            # get the detection bboxes for the first association
            first_det_inds = scores > self.obj_score_thrs['high']
            first_det_bboxes = bboxes[first_det_inds]
            first_det_labels = labels[first_det_inds]
            first_det_scores = scores[first_det_inds]
            first_det_ids = ids[first_det_inds]

            # get the detection bboxes for the second association
            second_det_inds = (~first_det_inds) & (
                scores > self.obj_score_thrs['low'])
            second_det_bboxes = bboxes[second_det_inds]
            second_det_labels = labels[second_det_inds]
            second_det_scores = scores[second_det_inds]
            second_det_ids = ids[second_det_inds]
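            # NOTE (added): keeping these low-score boxes for a second
            # association pass is the core BYTE idea: partially occluded
            # objects often persist as low-confidence detections.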

            # 1. use Kalman Filter to predict current location
            for id in self.confirmed_ids:
                # track is lost in previous frame
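                # NOTE (added): mean[7] is assumed to be the height-velocity
                # term of the 8-dim Kalman state; it is zeroed so a stale
                # velocity is not extrapolated for lost tracks.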
                if self.tracks[id].frame_ids[-1] != frame_id - 1:
                    self.tracks[id].mean[7] = 0
                (self.tracks[id].mean,
                 self.tracks[id].covariance) = self.kf.predict(
                     self.tracks[id].mean, self.tracks[id].covariance)

            # 2. first match
            first_match_track_inds, first_match_det_inds = self.assign_ids(
                self.confirmed_ids, first_det_bboxes, first_det_labels,
                first_det_scores, self.weight_iou_with_det_scores,
                self.match_iou_thrs['high'])
            # '-1' means a detection box is not matched with tracklets in
            # the previous frame
            valid = first_match_det_inds > -1
            first_det_ids[valid] = torch.tensor(
                self.confirmed_ids)[first_match_det_inds[valid]].to(labels)

            first_match_det_bboxes = first_det_bboxes[valid]
            first_match_det_labels = first_det_labels[valid]
            first_match_det_scores = first_det_scores[valid]
            first_match_det_ids = first_det_ids[valid]
            assert (first_match_det_ids > -1).all()

            first_unmatch_det_bboxes = first_det_bboxes[~valid]
            first_unmatch_det_labels = first_det_labels[~valid]
            first_unmatch_det_scores = first_det_scores[~valid]
            first_unmatch_det_ids = first_det_ids[~valid]
            assert (first_unmatch_det_ids == -1).all()

            # 3. use unmatched detection bboxes from the first match to match
            # the unconfirmed tracks
            (tentative_match_track_inds,
             tentative_match_det_inds) = self.assign_ids(
                 self.unconfirmed_ids, first_unmatch_det_bboxes,
                 first_unmatch_det_labels, first_unmatch_det_scores,
                 self.weight_iou_with_det_scores,
                 self.match_iou_thrs['tentative'])
            valid = tentative_match_det_inds > -1
            first_unmatch_det_ids[valid] = torch.tensor(self.unconfirmed_ids)[
                tentative_match_det_inds[valid]].to(labels)

            # 4. second match for unmatched tracks from the first match
            first_unmatch_track_ids = []
            for i, id in enumerate(self.confirmed_ids):
                # tracklet is not matched in the first match
                case_1 = first_match_track_inds[i] == -1
                # tracklet is not lost in the previous frame
                case_2 = self.tracks[id].frame_ids[-1] == frame_id - 1
                if case_1 and case_2:
                    first_unmatch_track_ids.append(id)

            second_match_track_inds, second_match_det_inds = self.assign_ids(
                first_unmatch_track_ids, second_det_bboxes, second_det_labels,
                second_det_scores, False, self.match_iou_thrs['low'])
            valid = second_match_det_inds > -1
            second_det_ids[valid] = torch.tensor(first_unmatch_track_ids)[
                second_match_det_inds[valid]].to(ids)

            # 5. gather all matched detection bboxes from steps 2-4
            # only the matched detection bboxes from the second match are
            # kept, i.e. those with id != -1
            valid = second_det_ids > -1
            bboxes = torch.cat(
                (first_match_det_bboxes, first_unmatch_det_bboxes), dim=0)
            bboxes = torch.cat((bboxes, second_det_bboxes[valid]), dim=0)
            labels = torch.cat(
                (first_match_det_labels, first_unmatch_det_labels), dim=0)
            labels = torch.cat((labels, second_det_labels[valid]), dim=0)
            scores = torch.cat(
                (first_match_det_scores, first_unmatch_det_scores), dim=0)
            scores = torch.cat((scores, second_det_scores[valid]), dim=0)
            ids = torch.cat((first_match_det_ids, first_unmatch_det_ids),
                            dim=0)
            ids = torch.cat((ids, second_det_ids[valid]), dim=0)

            # 6. assign new ids
            new_track_inds = ids == -1
            ids[new_track_inds] = torch.arange(
                self.num_tracks,
                self.num_tracks + new_track_inds.sum()).to(labels)
            self.num_tracks += new_track_inds.sum()

        self.update(
            ids=ids,
            bboxes=bboxes,
            scores=scores,
            labels=labels,
            frame_ids=frame_id)

        # update pred_track_instances
        pred_track_instances = InstanceData()
        pred_track_instances.bboxes = bboxes
        pred_track_instances.labels = labels
        pred_track_instances.scores = scores
        pred_track_instances.instances_id = ids
        return pred_track_instances
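
# NOTE (added): a minimal, hypothetical per-frame driver, assuming each
# `DetDataSample` carries `pred_instances` and a `frame_id` in its metainfo:
#
#   tracker = MODELS.build(
#       dict(type='ByteTracker', motion=dict(type='KalmanFilter')))
#   for data_sample in video_data_samples:
#       pred_track_instances = tracker.track(data_sample)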