# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple

try:
    import lap
except ImportError:
    lap = None
import numpy as np
import torch
from mmengine.structures import InstanceData

from mmdet.registry import MODELS, TASK_UTILS
from mmdet.structures import DetDataSample
from mmdet.structures.bbox import (bbox_cxcyah_to_xyxy, bbox_overlaps,
                                   bbox_xyxy_to_cxcyah)
from .base_tracker import BaseTracker


@MODELS.register_module()
class ByteTracker(BaseTracker):
    """Tracker for ByteTrack.

    Args:
        motion (dict): Configuration of the motion model. Defaults to None.
        obj_score_thrs (dict): Detection score thresholds for matching
            objects.
            - high (float): Threshold of the first matching. Defaults to 0.6.
            - low (float): Threshold of the second matching. Defaults to 0.1.
        init_track_thr (float): Detection score threshold for initializing a
            new tracklet. Defaults to 0.7.
        weight_iou_with_det_scores (bool): Whether to weight the IoU used for
            matching with the detection scores. Defaults to True.
        match_iou_thrs (dict): IoU thresholds for matching between two
            frames.
            - high (float): Threshold of the first matching. Defaults to 0.1.
            - low (float): Threshold of the second matching. Defaults to 0.5.
            - tentative (float): Threshold of the matching for tentative
                tracklets. Defaults to 0.3.
        num_tentatives (int, optional): Number of continuous frames to confirm
            a track. Defaults to 3.
    """

    def __init__(self,
                 motion: Optional[dict] = None,
                 obj_score_thrs: dict = dict(high=0.6, low=0.1),
                 init_track_thr: float = 0.7,
                 weight_iou_with_det_scores: bool = True,
                 match_iou_thrs: dict = dict(
                     high=0.1, low=0.5, tentative=0.3),
                 num_tentatives: int = 3,
                 **kwargs):
        super().__init__(**kwargs)
        if lap is None:
            raise RuntimeError('lap is not installed, '
                               'please install it by: pip install lap')
        if motion is not None:
            self.motion = TASK_UTILS.build(motion)

        self.obj_score_thrs = obj_score_thrs
        self.init_track_thr = init_track_thr

        self.weight_iou_with_det_scores = weight_iou_with_det_scores
        self.match_iou_thrs = match_iou_thrs

        self.num_tentatives = num_tentatives

    @property
    def confirmed_ids(self) -> List:
        """Confirmed ids in the tracker."""
        ids = [id for id, track in self.tracks.items() if not track.tentative]
        return ids

    @property
    def unconfirmed_ids(self) -> List:
        """Unconfirmed ids in the tracker."""
        ids = [id for id, track in self.tracks.items() if track.tentative]
        return ids

    def init_track(self, id: int, obj: Tuple[torch.Tensor]) -> None:
        """Initialize a track."""
        super().init_track(id, obj)
        if self.tracks[id].frame_ids[-1] == 0:
            self.tracks[id].tentative = False
        else:
            self.tracks[id].tentative = True
        bbox = bbox_xyxy_to_cxcyah(self.tracks[id].bboxes[-1])  # size (1, 4)
        assert bbox.ndim == 2 and bbox.shape[0] == 1
        bbox = bbox.squeeze(0).cpu().numpy()
        self.tracks[id].mean, self.tracks[id].covariance = self.kf.initiate(
            bbox)

    def update_track(self, id: int, obj: Tuple[torch.Tensor]) -> None:
        """Update a track."""
        super().update_track(id, obj)
        if self.tracks[id].tentative:
            if len(self.tracks[id]['bboxes']) >= self.num_tentatives:
                self.tracks[id].tentative = False
        bbox = bbox_xyxy_to_cxcyah(self.tracks[id].bboxes[-1])  # size (1, 4)
        assert bbox.ndim == 2 and bbox.shape[0] == 1
        bbox = bbox.squeeze(0).cpu().numpy()
        track_label = self.tracks[id]['labels'][-1]
        label_idx = self.memo_items.index('labels')
        obj_label = obj[label_idx]
        assert obj_label == track_label
        self.tracks[id].mean, self.tracks[id].covariance = self.kf.update(
            self.tracks[id].mean, self.tracks[id].covariance, bbox)

    def pop_invalid_tracks(self, frame_id: int) -> None:
        """Pop out invalid tracks."""
        invalid_ids = []
        for k, v in self.tracks.items():
            # case1: disappeared frames >= self.num_frames_retain
            case1 = frame_id - v['frame_ids'][-1] >= self.num_frames_retain
            # case2: tentative tracks that were not matched in this frame
            case2 = v.tentative and v['frame_ids'][-1] != frame_id
            if case1 or case2:
                invalid_ids.append(k)
        for invalid_id in invalid_ids:
            self.tracks.pop(invalid_id)
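
    # For example, assuming `num_frames_retain=30` (a hypothetical value):
    # a confirmed track last seen at frame 10 is removed once frame 40 is
    # reached, while a tentative track is removed as soon as it misses a
    # single frame.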

    def assign_ids(
        self,
        ids: List[int],
        det_bboxes: torch.Tensor,
        det_labels: torch.Tensor,
        det_scores: torch.Tensor,
        weight_iou_with_det_scores: Optional[bool] = False,
        match_iou_thr: Optional[float] = 0.5
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Assign ids.

        Args:
            ids (list[int]): Tracking ids.
            det_bboxes (Tensor): of shape (N, 4).
            det_labels (Tensor): of shape (N,).
            det_scores (Tensor): of shape (N,).
            weight_iou_with_det_scores (bool, optional): Whether to weight
                the IoU used for matching with the detection scores.
                Defaults to False.
            match_iou_thr (float, optional): Matching threshold.
                Defaults to 0.5.

        Returns:
            tuple(np.ndarray, np.ndarray): ``(row, col)``, where ``row[i]``
            is the index of the detection assigned to the i-th track (-1 if
            unmatched) and ``col[j]`` is the index of the track assigned to
            the j-th detection (-1 if unmatched).
        """
        # get track_bboxes
        track_bboxes = np.zeros((0, 4))
        for id in ids:
            track_bboxes = np.concatenate(
                (track_bboxes, self.tracks[id].mean[:4][None]), axis=0)
        track_bboxes = torch.from_numpy(track_bboxes).to(det_bboxes)
        track_bboxes = bbox_cxcyah_to_xyxy(track_bboxes)

        # compute distance
        ious = bbox_overlaps(track_bboxes, det_bboxes)
        if weight_iou_with_det_scores:
            ious *= det_scores

        # support multi-class association
        track_labels = torch.tensor([
            self.tracks[id]['labels'][-1] for id in ids
        ]).to(det_bboxes.device)
        cate_match = det_labels[None, :] == track_labels[:, None]
        # avoid matching detections and tracks of different categories
        cate_cost = (1 - cate_match.int()) * 1e6
        dists = (1 - ious + cate_cost).cpu().numpy()
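
        # e.g. a same-class pair with IoU 0.7 costs 0.3, while any
        # cross-class pair costs ~1e6 and can never fall under the
        # cost limit used below
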
        # bipartite match
        if dists.size > 0:
            cost, row, col = lap.lapjv(
                dists, extend_cost=True, cost_limit=1 - match_iou_thr)
        else:
            row = np.zeros(len(ids)).astype(np.int32) - 1
            col = np.zeros(len(det_bboxes)).astype(np.int32) - 1
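        # row[i]: detection index matched to the i-th track, or -1 if the
        # pair cost exceeded `cost_limit` (i.e. IoU below `match_iou_thr`);
        # col[j]: track index matched to the j-th detection, or -1. This
        # mirrors the all-unmatched fallback in the branch above.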
        return row, col

    def track(self, data_sample: DetDataSample, **kwargs) -> InstanceData:
        """Tracking forward function.

        Args:
            data_sample (:obj:`DetDataSample`): The data sample.
                It includes information such as `pred_instances`.

        Returns:
            :obj:`InstanceData`: Tracking results of the input images.
            Each InstanceData usually contains ``bboxes``, ``labels``,
            ``scores`` and ``instances_id``.
        """
        metainfo = data_sample.metainfo
        bboxes = data_sample.pred_instances.bboxes
        labels = data_sample.pred_instances.labels
        scores = data_sample.pred_instances.scores

        frame_id = metainfo.get('frame_id', -1)
        if frame_id == 0:
            self.reset()
        if not hasattr(self, 'kf'):
            self.kf = self.motion
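
        # If the tracker holds no live tracks (e.g. on the first frame of a
        # sequence), there is nothing to associate against: every detection
        # above `init_track_thr` directly seeds a new tracklet.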
        if self.empty or bboxes.size(0) == 0:
            valid_inds = scores > self.init_track_thr
            scores = scores[valid_inds]
            bboxes = bboxes[valid_inds]
            labels = labels[valid_inds]
            num_new_tracks = bboxes.size(0)
            ids = torch.arange(self.num_tracks,
                               self.num_tracks + num_new_tracks).to(labels)
            self.num_tracks += num_new_tracks
        else:
            # 0. init
            ids = torch.full((bboxes.size(0), ),
                             -1,
                             dtype=labels.dtype,
                             device=labels.device)

            # get the detection bboxes for the first association
            first_det_inds = scores > self.obj_score_thrs['high']
            first_det_bboxes = bboxes[first_det_inds]
            first_det_labels = labels[first_det_inds]
            first_det_scores = scores[first_det_inds]
            first_det_ids = ids[first_det_inds]

            # get the detection bboxes for the second association
            second_det_inds = (~first_det_inds) & (
                scores > self.obj_score_thrs['low'])
            second_det_bboxes = bboxes[second_det_inds]
            second_det_labels = labels[second_det_inds]
            second_det_scores = scores[second_det_inds]
            second_det_ids = ids[second_det_inds]

            # 1. use Kalman Filter to predict current location
            for id in self.confirmed_ids:
                # track is lost in previous frame
                if self.tracks[id].frame_ids[-1] != frame_id - 1:
                    self.tracks[id].mean[7] = 0
                (self.tracks[id].mean,
                 self.tracks[id].covariance) = self.kf.predict(
                     self.tracks[id].mean, self.tracks[id].covariance)
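
            # `mean[7]` is assumed to be the height velocity, the last
            # entry of the 8-dim Kalman state; zeroing it keeps a lost
            # track's box from inflating or shrinking while it coasts
            # through frames without observations.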

            # 2. first match
            first_match_track_inds, first_match_det_inds = self.assign_ids(
                self.confirmed_ids, first_det_bboxes, first_det_labels,
                first_det_scores, self.weight_iou_with_det_scores,
                self.match_iou_thrs['high'])
            # '-1' means a detection box is not matched with any tracklet
            # from the previous frame
            valid = first_match_det_inds > -1
            first_det_ids[valid] = torch.tensor(
                self.confirmed_ids)[first_match_det_inds[valid]].to(labels)

            first_match_det_bboxes = first_det_bboxes[valid]
            first_match_det_labels = first_det_labels[valid]
            first_match_det_scores = first_det_scores[valid]
            first_match_det_ids = first_det_ids[valid]
            assert (first_match_det_ids > -1).all()

            first_unmatch_det_bboxes = first_det_bboxes[~valid]
            first_unmatch_det_labels = first_det_labels[~valid]
            first_unmatch_det_scores = first_det_scores[~valid]
            first_unmatch_det_ids = first_det_ids[~valid]
            assert (first_unmatch_det_ids == -1).all()

            # 3. use unmatched detection bboxes from the first match to match
            # the unconfirmed tracks
            (tentative_match_track_inds,
             tentative_match_det_inds) = self.assign_ids(
                 self.unconfirmed_ids, first_unmatch_det_bboxes,
                 first_unmatch_det_labels, first_unmatch_det_scores,
                 self.weight_iou_with_det_scores,
                 self.match_iou_thrs['tentative'])
            valid = tentative_match_det_inds > -1
            first_unmatch_det_ids[valid] = torch.tensor(self.unconfirmed_ids)[
                tentative_match_det_inds[valid]].to(labels)
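
            # detections still carrying id == -1 after this step are kept:
            # they flow into the concatenation in step 5 and seed brand-new
            # tracks in step 6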

            # 4. second match for unmatched tracks from the first match
            first_unmatch_track_ids = []
            for i, id in enumerate(self.confirmed_ids):
                # tracklet is not matched in the first match
                case_1 = first_match_track_inds[i] == -1
                # tracklet is not lost in the previous frame
                case_2 = self.tracks[id].frame_ids[-1] == frame_id - 1
                if case_1 and case_2:
                    first_unmatch_track_ids.append(id)

            second_match_track_inds, second_match_det_inds = self.assign_ids(
                first_unmatch_track_ids, second_det_bboxes, second_det_labels,
                second_det_scores, False, self.match_iou_thrs['low'])
            valid = second_match_det_inds > -1
            second_det_ids[valid] = torch.tensor(first_unmatch_track_ids)[
                second_match_det_inds[valid]].to(ids)

            # 5. gather all matched detection bboxes from steps 2-4;
            # from the second match we only keep matched detection bboxes,
            # i.e. those with id != -1
            valid = second_det_ids > -1
            bboxes = torch.cat(
                (first_match_det_bboxes, first_unmatch_det_bboxes), dim=0)
            bboxes = torch.cat((bboxes, second_det_bboxes[valid]), dim=0)

            labels = torch.cat(
                (first_match_det_labels, first_unmatch_det_labels), dim=0)
            labels = torch.cat((labels, second_det_labels[valid]), dim=0)

            scores = torch.cat(
                (first_match_det_scores, first_unmatch_det_scores), dim=0)
            scores = torch.cat((scores, second_det_scores[valid]), dim=0)

            ids = torch.cat((first_match_det_ids, first_unmatch_det_ids),
                            dim=0)
            ids = torch.cat((ids, second_det_ids[valid]), dim=0)

            # 6. assign new ids
            new_track_inds = ids == -1
            ids[new_track_inds] = torch.arange(
                self.num_tracks,
                self.num_tracks + new_track_inds.sum()).to(labels)
            self.num_tracks += new_track_inds.sum()

        self.update(
            ids=ids,
            bboxes=bboxes,
            scores=scores,
            labels=labels,
            frame_ids=frame_id)

        # update pred_track_instances
        pred_track_instances = InstanceData()
        pred_track_instances.bboxes = bboxes
        pred_track_instances.labels = labels
        pred_track_instances.scores = scores
        pred_track_instances.instances_id = ids

        return pred_track_instances
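

# A minimal per-frame usage sketch. `detector_predict` and `video_len` are
# hypothetical placeholders, and the motion config assumes mmdet's
# registered `KalmanFilter`:
#
#   tracker = ByteTracker(motion=dict(type='KalmanFilter'))
#   for frame_id in range(video_len):
#       data_sample = detector_predict(frame_id)  # -> DetDataSample
#       data_sample.set_metainfo(dict(frame_id=frame_id))
#       pred_track_instances = tracker.track(data_sample)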