# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple

try:
    import lap
except ImportError:
    lap = None
import numpy as np
import torch
from addict import Dict
from mmengine.structures import InstanceData

from mmdet.registry import MODELS
from mmdet.structures import DetDataSample
from mmdet.structures.bbox import (bbox_cxcyah_to_xyxy, bbox_overlaps,
                                   bbox_xyxy_to_cxcyah)
from .sort_tracker import SORTTracker


@MODELS.register_module()
class OCSORTTracker(SORTTracker):
    """Tracker for OC-SORT.

    Args:
        motion (dict): Configuration of motion. Defaults to None.
        obj_score_thr (float): Detection score threshold for matching
            objects. Defaults to 0.3.
        init_track_thr (float): Detection score threshold for initializing a
            new tracklet. Defaults to 0.7.
        weight_iou_with_det_scores (bool): Whether to weight the IoU used
            for matching with the detection scores. Defaults to True.
        match_iou_thr (float): IoU distance threshold for matching between
            two frames. Defaults to 0.3.
        num_tentatives (int, optional): Number of continuous frames to
            confirm a track. Defaults to 3.
        vel_consist_weight (float): Weight of the velocity consistency term
            in association (the OCM term in the paper). Defaults to 0.2.
        vel_delta_t (int): The time-step difference used for estimating the
            velocity direction of tracklets. Defaults to 3.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Defaults to None.
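
    Examples:
        A hedged construction sketch (shown as comments because building the
        tracker needs a registered motion config; ``KalmanFilter`` is the
        type name commonly registered in mmdet, adjust to your setup):

        >>> # tracker = OCSORTTracker(
        >>> #     motion=dict(type='KalmanFilter'),
        >>> #     obj_score_thr=0.3,
        >>> #     init_track_thr=0.7)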
  37. """

    def __init__(self,
                 motion: Optional[dict] = None,
                 obj_score_thr: float = 0.3,
                 init_track_thr: float = 0.7,
                 weight_iou_with_det_scores: bool = True,
                 match_iou_thr: float = 0.3,
                 num_tentatives: int = 3,
                 vel_consist_weight: float = 0.2,
                 vel_delta_t: int = 3,
                 **kwargs):
        if lap is None:
            raise RuntimeError('lap is not installed, '
                               'please install it by: pip install lap')
        super().__init__(motion=motion, **kwargs)
        self.obj_score_thr = obj_score_thr
        self.init_track_thr = init_track_thr
        self.weight_iou_with_det_scores = weight_iou_with_det_scores
        self.match_iou_thr = match_iou_thr
        self.vel_consist_weight = vel_consist_weight
        self.vel_delta_t = vel_delta_t
        self.num_tentatives = num_tentatives

    @property
    def unconfirmed_ids(self):
        """Unconfirmed ids in the tracker."""
        ids = [id for id, track in self.tracks.items() if track.tentative]
        return ids

    def init_track(self, id: int, obj: Tuple[torch.Tensor]):
        """Initialize a track.

        Besides the Kalman filter state, OC-SORT also keeps the history of
        observations associated to the track (``track.obs``) and a velocity
        direction placeholder.
        """
        super().init_track(id, obj)
        if self.tracks[id].frame_ids[-1] == 0:
            self.tracks[id].tentative = False
        else:
            self.tracks[id].tentative = True
        bbox = bbox_xyxy_to_cxcyah(self.tracks[id].bboxes[-1])  # size = (1, 4)
        assert bbox.ndim == 2 and bbox.shape[0] == 1
        bbox = bbox.squeeze(0).cpu().numpy()
        self.tracks[id].mean, self.tracks[id].covariance = self.kf.initiate(
            bbox)
        # track.obs maintains the history of detections associated to this
        # track
        self.tracks[id].obs = []
        bbox_id = self.memo_items.index('bboxes')
        self.tracks[id].obs.append(obj[bbox_id])
        # a placeholder to save mean/covariance before losing track of the
        # object; parameters to save: mean, covariance, measurement
        self.tracks[id].tracked = True
        self.tracks[id].saved_attr = Dict()
        self.tracks[id].velocity = torch.tensor(
            (-1, -1)).to(obj[bbox_id].device)  # placeholder

    def update_track(self, id: int, obj: Tuple[torch.Tensor]):
        """Update a track."""
        super().update_track(id, obj)
        if self.tracks[id].tentative:
            if len(self.tracks[id]['bboxes']) >= self.num_tentatives:
                self.tracks[id].tentative = False
        bbox = bbox_xyxy_to_cxcyah(self.tracks[id].bboxes[-1])  # size = (1, 4)
        assert bbox.ndim == 2 and bbox.shape[0] == 1
        bbox = bbox.squeeze(0).cpu().numpy()
        self.tracks[id].mean, self.tracks[id].covariance = self.kf.update(
            self.tracks[id].mean, self.tracks[id].covariance, bbox)
        self.tracks[id].tracked = True
        bbox_id = self.memo_items.index('bboxes')
        self.tracks[id].obs.append(obj[bbox_id])

        bbox1 = self.k_step_observation(self.tracks[id])
        bbox2 = obj[bbox_id]
        self.tracks[id].velocity = self.vel_direction(bbox1, bbox2).to(
            obj[bbox_id].device)

    def vel_direction(self, bbox1: torch.Tensor, bbox2: torch.Tensor):
        """Estimate the unit direction vector pointing from ``bbox1`` to
        ``bbox2``. Returns the placeholder ``(-1, -1)`` if either box is the
        "no observation" placeholder.
        """
        if bbox1.sum() < 0 or bbox2.sum() < 0:
            return torch.tensor((-1, -1))
        cx1, cy1 = (bbox1[0] + bbox1[2]) / 2.0, (bbox1[1] + bbox1[3]) / 2.0
        cx2, cy2 = (bbox2[0] + bbox2[2]) / 2.0, (bbox2[1] + bbox2[3]) / 2.0
        speed = torch.tensor([cy2 - cy1, cx2 - cx1])
        norm = torch.sqrt((speed[0])**2 + (speed[1])**2) + 1e-6
        return speed / norm

    def vel_direction_batch(self, bboxes1: torch.Tensor,
                            bboxes2: torch.Tensor):
        """Estimate the pairwise unit direction vectors between two batches
        of boxes.
        """
        cx1, cy1 = (bboxes1[:, 0] + bboxes1[:, 2]) / 2.0, (bboxes1[:, 1] +
                                                           bboxes1[:, 3]) / 2.0
        cx2, cy2 = (bboxes2[:, 0] + bboxes2[:, 2]) / 2.0, (bboxes2[:, 1] +
                                                           bboxes2[:, 3]) / 2.0
        speed_diff_y = cy2[None, :] - cy1[:, None]
        speed_diff_x = cx2[None, :] - cx1[:, None]
        speed = torch.cat((speed_diff_y[..., None], speed_diff_x[..., None]),
                          dim=-1)
        norm = torch.sqrt((speed[:, :, 0])**2 + (speed[:, :, 1])**2) + 1e-6
        speed[:, :, 0] /= norm
        speed[:, :, 1] /= norm
        return speed

    def k_step_observation(self, track: Dict):
        """Return the observation from ``vel_delta_t`` steps back, falling
        back to the last associated observation when it is unavailable.
        """
        obs_seqs = track.obs
        num_obs = len(obs_seqs)
        if num_obs == 0:
            # ``track.obs`` is empty, so take the device from the track's
            # bboxes (indexing ``track.obs[0]`` here would raise IndexError)
            return torch.tensor((-1, -1, -1, -1)).to(track.bboxes[-1].device)
        elif num_obs > self.vel_delta_t:
            if obs_seqs[num_obs - 1 - self.vel_delta_t] is not None:
                return obs_seqs[num_obs - 1 - self.vel_delta_t]
            else:
                return self.last_obs(track)
        else:
            return self.last_obs(track)

    def ocm_assign_ids(self,
                       ids: List[int],
                       det_bboxes: torch.Tensor,
                       det_labels: torch.Tensor,
                       det_scores: torch.Tensor,
                       weight_iou_with_det_scores: bool = False,
                       match_iou_thr: float = 0.5):
        """Apply Observation-Centric Momentum (OCM) to assign ids.

        OCM adds a movement-direction consistency term to the association
        cost matrix, so OC-SORT uses velocity consistency besides IoU for
        association. The term requires no additional assumption beyond the
        linear motion assumption of the canonical Kalman filter in SORT.

        Args:
            ids (list[int]): Tracking ids.
            det_bboxes (Tensor): of shape (N, 4)
            det_labels (Tensor): of shape (N,)
            det_scores (Tensor): of shape (N,)
            weight_iou_with_det_scores (bool, optional): Whether to weight
                the IoU used for matching with the detection scores.
                Defaults to False.
            match_iou_thr (float, optional): Matching threshold.
                Defaults to 0.5.

        Returns:
            tuple(np.ndarray): ``(row, col)``, where ``row[i]`` is the index
            of the detection matched to track ``ids[i]`` and ``col[j]`` is
            the index of the track matched to detection ``j``; ``-1`` means
            unmatched.
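
        Examples:
            A sketch of the direction-consistency cost with illustrative
            values. The angle between a tracklet's historical velocity
            direction and the track-to-detection direction is normalized
            from ``[0, pi]`` to ``[-0.5, 0.5]``, so aligned directions lower
            the matching cost:

            >>> import numpy as np
            >>> import torch
            >>> angle_cos = torch.tensor([1., 0., -1.])  # aligned..opposite
            >>> angle = torch.acos(torch.clamp(angle_cos, min=-1, max=1))
            >>> norm_angle = (angle - np.pi / 2.) / np.pi
            >>> # norm_angle is [-0.5, 0.0, 0.5]; it is scaled by
            >>> # ``vel_consist_weight`` and added to the (1 - IoU) cost.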
  165. """
  166. # get track_bboxes
  167. track_bboxes = np.zeros((0, 4))
  168. for id in ids:
  169. track_bboxes = np.concatenate(
  170. (track_bboxes, self.tracks[id].mean[:4][None]), axis=0)
  171. track_bboxes = torch.from_numpy(track_bboxes).to(det_bboxes)
  172. track_bboxes = bbox_cxcyah_to_xyxy(track_bboxes)
  173. # compute distance
  174. ious = bbox_overlaps(track_bboxes, det_bboxes)
  175. if weight_iou_with_det_scores:
  176. ious *= det_scores
  177. # support multi-class association
  178. track_labels = torch.tensor([
  179. self.tracks[id]['labels'][-1] for id in ids
  180. ]).to(det_bboxes.device)
  181. cate_match = det_labels[None, :] == track_labels[:, None]
  182. # to avoid det and track of different categories are matched
  183. cate_cost = (1 - cate_match.int()) * 1e6
  184. dists = (1 - ious + cate_cost).cpu().numpy()
  185. if len(ids) > 0 and len(det_bboxes) > 0:
  186. track_velocities = torch.stack(
  187. [self.tracks[id].velocity for id in ids]).to(det_bboxes.device)
  188. k_step_observations = torch.stack([
  189. self.k_step_observation(self.tracks[id]) for id in ids
  190. ]).to(det_bboxes.device)
  191. # valid1: if the track has previous observations to estimate speed
  192. # valid2: if the associated observation k steps ago is a detection
  193. valid1 = track_velocities.sum(dim=1) != -2
  194. valid2 = k_step_observations.sum(dim=1) != -4
  195. valid = valid1 & valid2
  196. vel_to_match = self.vel_direction_batch(k_step_observations,
  197. det_bboxes)
  198. track_velocities = track_velocities[:, None, :].repeat(
  199. 1, det_bboxes.shape[0], 1)
  200. angle_cos = (vel_to_match * track_velocities).sum(dim=-1)
  201. angle_cos = torch.clamp(angle_cos, min=-1, max=1)
  202. angle = torch.acos(angle_cos) # [0, pi]
  203. norm_angle = (angle - np.pi / 2.) / np.pi # [-0.5, 0.5]
  204. valid_matrix = valid[:, None].int().repeat(1, det_bboxes.shape[0])
  205. # set non-valid entries 0
  206. valid_norm_angle = norm_angle * valid_matrix
  207. dists += valid_norm_angle.cpu().numpy() * self.vel_consist_weight
  208. # bipartite match
  209. if dists.size > 0:
  210. cost, row, col = lap.lapjv(
  211. dists, extend_cost=True, cost_limit=1 - match_iou_thr)
  212. else:
  213. row = np.zeros(len(ids)).astype(np.int32) - 1
  214. col = np.zeros(len(det_bboxes)).astype(np.int32) - 1
  215. return row, col

    def last_obs(self, track: Dict):
        """Extract the last associated observation."""
        for bbox in track.obs[::-1]:
            if bbox is not None:
                return bbox

    def ocr_assign_ids(self,
                       track_obs: torch.Tensor,
                       last_track_labels: torch.Tensor,
                       det_bboxes: torch.Tensor,
                       det_labels: torch.Tensor,
                       det_scores: torch.Tensor,
                       weight_iou_with_det_scores: bool = False,
                       match_iou_thr: float = 0.5):
        """Association for Observation-Centric Recovery (OCR).

        When trying to recover tracks from being lost, their estimated
        velocities are out of date, so an IoU-only matching strategy is
        used against each lost track's last associated observation.

        Args:
            track_obs (Tensor): of shape (M, 4), the last associated
                observations of the unmatched tracks
            last_track_labels (Tensor): of shape (M,), the last labels of
                the unmatched tracks
            det_bboxes (Tensor): of shape (N, 4), unmatched detections
            det_labels (Tensor): of shape (N,)
            det_scores (Tensor): of shape (N,)
            weight_iou_with_det_scores (bool, optional): Whether to weight
                the IoU used for matching with the detection scores.
                Defaults to False.
            match_iou_thr (float, optional): Matching threshold.
                Defaults to 0.5.

        Returns:
            tuple(np.ndarray): ``(row, col)`` assignment indices as in
            ``ocm_assign_ids``; ``-1`` means unmatched.
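
        Examples:
            A sketch of the bipartite matching step on a tiny hand-made
            cost matrix (requires the ``lap`` package; values are
            illustrative):

            >>> import numpy as np
            >>> import lap
            >>> dists = np.array([[0.2, 0.9],
            ...                   [0.9, 0.8]])
            >>> cost, row, col = lap.lapjv(
            ...     dists, extend_cost=True, cost_limit=0.7)
            >>> # row[i] is the detection matched to track i (-1 if none):
            >>> # here track 0 -> det 0, while track 1 and det 1 exceed the
            >>> # cost limit and stay unmatched.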
  245. """
  246. # compute distance
  247. ious = bbox_overlaps(track_obs, det_bboxes)
  248. if weight_iou_with_det_scores:
  249. ious *= det_scores
  250. # support multi-class association
  251. cate_match = det_labels[None, :] == last_track_labels[:, None]
  252. # to avoid det and track of different categories are matched
  253. cate_cost = (1 - cate_match.int()) * 1e6
  254. dists = (1 - ious + cate_cost).cpu().numpy()
  255. # bipartite match
  256. if dists.size > 0:
  257. cost, row, col = lap.lapjv(
  258. dists, extend_cost=True, cost_limit=1 - match_iou_thr)
  259. else:
  260. row = np.zeros(len(track_obs)).astype(np.int32) - 1
  261. col = np.zeros(len(det_bboxes)).astype(np.int32) - 1
  262. return row, col

    def online_smooth(self, track: Dict, obj: torch.Tensor):
        """Once a track is recovered from being lost, smooth its parameters
        online to fix the error accumulated while it was lost.

        NOTE: different virtual-trajectory generation strategies can be
        used; naive linear interpolation is adopted as the default.
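
        Examples:
            A sketch of the virtual boxes generated by linear interpolation,
            with illustrative values and two unmatched frames between the
            last and the new observation:

            >>> import torch
            >>> last = torch.tensor([0., 0., 2., 2.])
            >>> new = torch.tensor([6., 0., 8., 2.])
            >>> unmatch_len = 2
            >>> step = (new - last) / (unmatch_len + 1)
            >>> [(last + (i + 1) * step).tolist() for i in range(unmatch_len)]
            [[2.0, 0.0, 4.0, 2.0], [4.0, 0.0, 6.0, 2.0]]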
  268. """
  269. last_match_bbox = self.last_obs(track)
  270. new_match_bbox = obj
  271. unmatch_len = 0
  272. for bbox in track.obs[::-1]:
  273. if bbox is None:
  274. unmatch_len += 1
  275. else:
  276. break
  277. bbox_shift_per_step = (new_match_bbox - last_match_bbox) / (
  278. unmatch_len + 1)
  279. track.mean = track.saved_attr.mean
  280. track.covariance = track.saved_attr.covariance
  281. for i in range(unmatch_len):
  282. virtual_bbox = last_match_bbox + (i + 1) * bbox_shift_per_step
  283. virtual_bbox = bbox_xyxy_to_cxcyah(virtual_bbox[None, :])
  284. virtual_bbox = virtual_bbox.squeeze(0).cpu().numpy()
  285. track.mean, track.covariance = self.kf.update(
  286. track.mean, track.covariance, virtual_bbox)

    def track(self, data_sample: DetDataSample, **kwargs) -> InstanceData:
        """Tracking forward function.

        NOTE: this implementation differs slightly from the original OC-SORT
        implementation (https://github.com/noahcao/OC_SORT) in that the
        association between detections and tentative/non-tentative tracks is
        done independently here, while the original implementation combines
        them.

        Args:
            data_sample (:obj:`DetDataSample`): The data sample.
                It includes information such as `pred_instances`.

        Returns:
            :obj:`InstanceData`: Tracking results of the input images.
            Each InstanceData usually contains ``bboxes``, ``labels``,
            ``scores`` and ``instances_id``.
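
        Examples:
            A hedged usage sketch, shown as comments because both the
            tracker and the ``DetDataSample`` need setup-dependent
            construction (detector outputs, registered motion config):

            >>> # tracker = MODELS.build(tracker_cfg)  # an OCSORTTracker
            >>> # pred_track_instances = tracker.track(data_sample)
            >>> # pred_track_instances.instances_id  # per-box track ids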
  300. """
  301. metainfo = data_sample.metainfo
  302. bboxes = data_sample.pred_instances.bboxes
  303. labels = data_sample.pred_instances.labels
  304. scores = data_sample.pred_instances.scores
  305. frame_id = metainfo.get('frame_id', -1)
  306. if frame_id == 0:
  307. self.reset()
  308. if not hasattr(self, 'kf'):
  309. self.kf = self.motion
  310. if self.empty or bboxes.size(0) == 0:
  311. valid_inds = scores > self.init_track_thr
  312. scores = scores[valid_inds]
  313. bboxes = bboxes[valid_inds]
  314. labels = labels[valid_inds]
  315. num_new_tracks = bboxes.size(0)
  316. ids = torch.arange(self.num_tracks,
  317. self.num_tracks + num_new_tracks).to(labels)
  318. self.num_tracks += num_new_tracks
        else:
            # 0. init
            ids = torch.full((bboxes.size(0), ),
                             -1,
                             dtype=labels.dtype,
                             device=labels.device)

            # get the detection bboxes for the first association
            det_inds = scores > self.obj_score_thr
            det_bboxes = bboxes[det_inds]
            det_labels = labels[det_inds]
            det_scores = scores[det_inds]
            det_ids = ids[det_inds]

            # 1. predict by Kalman Filter
            for id in self.confirmed_ids:
                # track is lost in previous frame
                if self.tracks[id].frame_ids[-1] != frame_id - 1:
                    self.tracks[id].mean[7] = 0
                if self.tracks[id].tracked:
                    self.tracks[id].saved_attr.mean = self.tracks[id].mean
                    self.tracks[id].saved_attr.covariance = self.tracks[
                        id].covariance
                (self.tracks[id].mean,
                 self.tracks[id].covariance) = self.kf.predict(
                     self.tracks[id].mean, self.tracks[id].covariance)

            # 2. match detections and tracks' predicted locations
            match_track_inds, raw_match_det_inds = self.ocm_assign_ids(
                self.confirmed_ids, det_bboxes, det_labels, det_scores,
                self.weight_iou_with_det_scores, self.match_iou_thr)
            # '-1' means a detection box is not matched with tracklets in
            # the previous frame
            valid = raw_match_det_inds > -1
            det_ids[valid] = torch.tensor(
                self.confirmed_ids)[raw_match_det_inds[valid]].to(labels)

            match_det_bboxes = det_bboxes[valid]
            match_det_labels = det_labels[valid]
            match_det_scores = det_scores[valid]
            match_det_ids = det_ids[valid]
            assert (match_det_ids > -1).all()

            # unmatched tracks and detections
            unmatch_det_bboxes = det_bboxes[~valid]
            unmatch_det_labels = det_labels[~valid]
            unmatch_det_scores = det_scores[~valid]
            unmatch_det_ids = det_ids[~valid]
            assert (unmatch_det_ids == -1).all()

            # 3. use unmatched detection bboxes from the first match to match
            # the unconfirmed tracks
            (tentative_match_track_inds,
             tentative_match_det_inds) = self.ocm_assign_ids(
                 self.unconfirmed_ids, unmatch_det_bboxes, unmatch_det_labels,
                 unmatch_det_scores, self.weight_iou_with_det_scores,
                 self.match_iou_thr)
            valid = tentative_match_det_inds > -1
            unmatch_det_ids[valid] = torch.tensor(self.unconfirmed_ids)[
                tentative_match_det_inds[valid]].to(labels)

            match_det_bboxes = torch.cat(
                (match_det_bboxes, unmatch_det_bboxes[valid]), dim=0)
            match_det_labels = torch.cat(
                (match_det_labels, unmatch_det_labels[valid]), dim=0)
            match_det_scores = torch.cat(
                (match_det_scores, unmatch_det_scores[valid]), dim=0)
            match_det_ids = torch.cat((match_det_ids, unmatch_det_ids[valid]),
                                      dim=0)
            assert (match_det_ids > -1).all()

            unmatch_det_bboxes = unmatch_det_bboxes[~valid]
            unmatch_det_labels = unmatch_det_labels[~valid]
            unmatch_det_scores = unmatch_det_scores[~valid]
            unmatch_det_ids = unmatch_det_ids[~valid]
            assert (unmatch_det_ids == -1).all()
            all_track_ids = [id for id, _ in self.tracks.items()]
            unmatched_track_inds = torch.tensor(
                [ind for ind in all_track_ids if ind not in match_det_ids])

            if len(unmatched_track_inds) > 0:
                # 4. still some tracks not associated yet, perform OCR
                last_observations = []
                for id in unmatched_track_inds:
                    last_box = self.last_obs(self.tracks[id.item()])
                    last_observations.append(last_box)
                last_observations = torch.stack(last_observations)
                last_track_labels = torch.tensor([
                    self.tracks[id.item()]['labels'][-1]
                    for id in unmatched_track_inds
                ]).to(det_bboxes.device)

                remain_det_ids = torch.full((unmatch_det_bboxes.size(0), ),
                                            -1,
                                            dtype=labels.dtype,
                                            device=labels.device)

                _, ocr_match_det_inds = self.ocr_assign_ids(
                    last_observations, last_track_labels, unmatch_det_bboxes,
                    unmatch_det_labels, unmatch_det_scores,
                    self.weight_iou_with_det_scores, self.match_iou_thr)

                valid = ocr_match_det_inds > -1
                remain_det_ids[valid] = unmatched_track_inds.clone()[
                    ocr_match_det_inds[valid]].to(labels)

                ocr_match_det_bboxes = unmatch_det_bboxes[valid]
                ocr_match_det_labels = unmatch_det_labels[valid]
                ocr_match_det_scores = unmatch_det_scores[valid]
                ocr_match_det_ids = remain_det_ids[valid]
                assert (ocr_match_det_ids > -1).all()

                ocr_unmatch_det_bboxes = unmatch_det_bboxes[~valid]
                ocr_unmatch_det_labels = unmatch_det_labels[~valid]
                ocr_unmatch_det_scores = unmatch_det_scores[~valid]
                ocr_unmatch_det_ids = remain_det_ids[~valid]
                assert (ocr_unmatch_det_ids == -1).all()

                unmatch_det_bboxes = ocr_unmatch_det_bboxes
                unmatch_det_labels = ocr_unmatch_det_labels
                unmatch_det_scores = ocr_unmatch_det_scores
                unmatch_det_ids = ocr_unmatch_det_ids
                match_det_bboxes = torch.cat(
                    (match_det_bboxes, ocr_match_det_bboxes), dim=0)
                match_det_labels = torch.cat(
                    (match_det_labels, ocr_match_det_labels), dim=0)
                match_det_scores = torch.cat(
                    (match_det_scores, ocr_match_det_scores), dim=0)
                match_det_ids = torch.cat((match_det_ids, ocr_match_det_ids),
                                          dim=0)
            # 5. summarize the track results
            for i in range(len(match_det_ids)):
                det_bbox = match_det_bboxes[i]
                track_id = match_det_ids[i].item()
                if not self.tracks[track_id].tracked:
                    # the track was lost before this step
                    self.online_smooth(self.tracks[track_id], det_bbox)

            for track_id in all_track_ids:
                if track_id not in match_det_ids:
                    self.tracks[track_id].tracked = False
                    self.tracks[track_id].obs.append(None)

            bboxes = torch.cat((match_det_bboxes, unmatch_det_bboxes), dim=0)
            labels = torch.cat((match_det_labels, unmatch_det_labels), dim=0)
            scores = torch.cat((match_det_scores, unmatch_det_scores), dim=0)
            ids = torch.cat((match_det_ids, unmatch_det_ids), dim=0)

            # 6. assign new ids
            new_track_inds = ids == -1
            ids[new_track_inds] = torch.arange(
                self.num_tracks,
                self.num_tracks + new_track_inds.sum()).to(labels)
            self.num_tracks += new_track_inds.sum()

        self.update(
            ids=ids,
            bboxes=bboxes,
            labels=labels,
            scores=scores,
            frame_ids=frame_id)

        # update pred_track_instances
        pred_track_instances = InstanceData()
        pred_track_instances.bboxes = bboxes
        pred_track_instances.labels = labels
        pred_track_instances.scores = scores
        pred_track_instances.instances_id = ids

        return pred_track_instances