quasi_dense_tracker.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import List, Tuple
  3. import torch
  4. import torch.nn.functional as F
  5. from mmengine.structures import InstanceData
  6. from torch import Tensor
  7. from mmdet.registry import MODELS
  8. from mmdet.structures import TrackDataSample
  9. from mmdet.structures.bbox import bbox_overlaps
  10. from .base_tracker import BaseTracker
  11. @MODELS.register_module()
  12. class QuasiDenseTracker(BaseTracker):
  13. """Tracker for Quasi-Dense Tracking.
  14. Args:
  15. init_score_thr (float): The cls_score threshold to
  16. initialize a new tracklet. Defaults to 0.8.
  17. obj_score_thr (float): The cls_score threshold to
  18. update a tracked tracklet. Defaults to 0.5.
  19. match_score_thr (float): The match threshold. Defaults to 0.5.
  20. memo_tracklet_frames (int): The most frames in a tracklet memory.
  21. Defaults to 10.
  22. memo_backdrop_frames (int): The most frames in the backdrops.
  23. Defaults to 1.
  24. memo_momentum (float): The momentum value for embeds updating.
  25. Defaults to 0.8.
  26. nms_conf_thr (float): The nms threshold for confidence.
  27. Defaults to 0.5.
  28. nms_backdrop_iou_thr (float): The nms threshold for backdrop IoU.
  29. Defaults to 0.3.
  30. nms_class_iou_thr (float): The nms threshold for class IoU.
  31. Defaults to 0.7.
  32. with_cats (bool): Whether to track with the same category.
  33. Defaults to True.
  34. match_metric (str): The match metric. Defaults to 'bisoftmax'.
  35. """
  36. def __init__(self,
  37. init_score_thr: float = 0.8,
  38. obj_score_thr: float = 0.5,
  39. match_score_thr: float = 0.5,
  40. memo_tracklet_frames: int = 10,
  41. memo_backdrop_frames: int = 1,
  42. memo_momentum: float = 0.8,
  43. nms_conf_thr: float = 0.5,
  44. nms_backdrop_iou_thr: float = 0.3,
  45. nms_class_iou_thr: float = 0.7,
  46. with_cats: bool = True,
  47. match_metric: str = 'bisoftmax',
  48. **kwargs):
  49. super().__init__(**kwargs)
  50. assert 0 <= memo_momentum <= 1.0
  51. assert memo_tracklet_frames >= 0
  52. assert memo_backdrop_frames >= 0
  53. self.init_score_thr = init_score_thr
  54. self.obj_score_thr = obj_score_thr
  55. self.match_score_thr = match_score_thr
  56. self.memo_tracklet_frames = memo_tracklet_frames
  57. self.memo_backdrop_frames = memo_backdrop_frames
  58. self.memo_momentum = memo_momentum
  59. self.nms_conf_thr = nms_conf_thr
  60. self.nms_backdrop_iou_thr = nms_backdrop_iou_thr
  61. self.nms_class_iou_thr = nms_class_iou_thr
  62. self.with_cats = with_cats
  63. assert match_metric in ['bisoftmax', 'softmax', 'cosine']
  64. self.match_metric = match_metric
  65. self.num_tracks = 0
  66. self.tracks = dict()
  67. self.backdrops = []
  68. def reset(self):
  69. """Reset the buffer of the tracker."""
  70. self.num_tracks = 0
  71. self.tracks = dict()
  72. self.backdrops = []
  73. def update(self, ids: Tensor, bboxes: Tensor, embeds: Tensor,
  74. labels: Tensor, scores: Tensor, frame_id: int) -> None:
  75. """Tracking forward function.
  76. Args:
  77. ids (Tensor): of shape(N, ).
  78. bboxes (Tensor): of shape (N, 5).
  79. embeds (Tensor): of shape (N, 256).
  80. labels (Tensor): of shape (N, ).
  81. scores (Tensor): of shape (N, ).
  82. frame_id (int): The id of current frame, 0-index.
  83. """
  84. tracklet_inds = ids > -1
  85. for id, bbox, embed, label, score in zip(ids[tracklet_inds],
  86. bboxes[tracklet_inds],
  87. embeds[tracklet_inds],
  88. labels[tracklet_inds],
  89. scores[tracklet_inds]):
  90. id = int(id)
  91. # update the tracked ones and initialize new tracks
  92. if id in self.tracks.keys():
  93. velocity = (bbox - self.tracks[id]['bbox']) / (
  94. frame_id - self.tracks[id]['last_frame'])
  95. self.tracks[id]['bbox'] = bbox
  96. self.tracks[id]['embed'] = (
  97. 1 - self.memo_momentum
  98. ) * self.tracks[id]['embed'] + self.memo_momentum * embed
  99. self.tracks[id]['last_frame'] = frame_id
  100. self.tracks[id]['label'] = label
  101. self.tracks[id]['score'] = score
  102. self.tracks[id]['velocity'] = (
  103. self.tracks[id]['velocity'] * self.tracks[id]['acc_frame']
  104. + velocity) / (
  105. self.tracks[id]['acc_frame'] + 1)
  106. self.tracks[id]['acc_frame'] += 1
  107. else:
  108. self.tracks[id] = dict(
  109. bbox=bbox,
  110. embed=embed,
  111. label=label,
  112. score=score,
  113. last_frame=frame_id,
  114. velocity=torch.zeros_like(bbox),
  115. acc_frame=0)
  116. # backdrop update according to IoU
  117. backdrop_inds = torch.nonzero(ids == -1, as_tuple=False).squeeze(1)
  118. ious = bbox_overlaps(bboxes[backdrop_inds], bboxes)
  119. for i, ind in enumerate(backdrop_inds):
  120. if (ious[i, :ind] > self.nms_backdrop_iou_thr).any():
  121. backdrop_inds[i] = -1
  122. backdrop_inds = backdrop_inds[backdrop_inds > -1]
  123. # old backdrops would be removed at first
  124. self.backdrops.insert(
  125. 0,
  126. dict(
  127. bboxes=bboxes[backdrop_inds],
  128. embeds=embeds[backdrop_inds],
  129. labels=labels[backdrop_inds]))
  130. # pop memo
  131. invalid_ids = []
  132. for k, v in self.tracks.items():
  133. if frame_id - v['last_frame'] >= self.memo_tracklet_frames:
  134. invalid_ids.append(k)
  135. for invalid_id in invalid_ids:
  136. self.tracks.pop(invalid_id)
  137. if len(self.backdrops) > self.memo_backdrop_frames:
  138. self.backdrops.pop()
  139. @property
  140. def memo(self) -> Tuple[Tensor, ...]:
  141. """Get tracks memory."""
  142. memo_embeds = []
  143. memo_ids = []
  144. memo_bboxes = []
  145. memo_labels = []
  146. # velocity of tracks
  147. memo_vs = []
  148. # get tracks
  149. for k, v in self.tracks.items():
  150. memo_bboxes.append(v['bbox'][None, :])
  151. memo_embeds.append(v['embed'][None, :])
  152. memo_ids.append(k)
  153. memo_labels.append(v['label'].view(1, 1))
  154. memo_vs.append(v['velocity'][None, :])
  155. memo_ids = torch.tensor(memo_ids, dtype=torch.long).view(1, -1)
  156. # get backdrops
  157. for backdrop in self.backdrops:
  158. backdrop_ids = torch.full((1, backdrop['embeds'].size(0)),
  159. -1,
  160. dtype=torch.long)
  161. backdrop_vs = torch.zeros_like(backdrop['bboxes'])
  162. memo_bboxes.append(backdrop['bboxes'])
  163. memo_embeds.append(backdrop['embeds'])
  164. memo_ids = torch.cat([memo_ids, backdrop_ids], dim=1)
  165. memo_labels.append(backdrop['labels'][:, None])
  166. memo_vs.append(backdrop_vs)
  167. memo_bboxes = torch.cat(memo_bboxes, dim=0)
  168. memo_embeds = torch.cat(memo_embeds, dim=0)
  169. memo_labels = torch.cat(memo_labels, dim=0).squeeze(1)
  170. memo_vs = torch.cat(memo_vs, dim=0)
  171. return memo_bboxes, memo_labels, memo_embeds, memo_ids.squeeze(
  172. 0), memo_vs
  173. def track(self,
  174. model: torch.nn.Module,
  175. img: torch.Tensor,
  176. feats: List[torch.Tensor],
  177. data_sample: TrackDataSample,
  178. rescale=True,
  179. **kwargs) -> InstanceData:
  180. """Tracking forward function.
  181. Args:
  182. model (nn.Module): MOT model.
  183. img (Tensor): of shape (T, C, H, W) encoding input image.
  184. Typically these should be mean centered and std scaled.
  185. The T denotes the number of key images and usually is 1 in
  186. QDTrack method.
  187. feats (list[Tensor]): Multi level feature maps of `img`.
  188. data_sample (:obj:`TrackDataSample`): The data sample.
  189. It includes information such as `pred_instances`.
  190. rescale (bool, optional): If True, the bounding boxes should be
  191. rescaled to fit the original scale of the image. Defaults to
  192. True.
  193. Returns:
  194. :obj:`InstanceData`: Tracking results of the input images.
  195. Each InstanceData usually contains ``bboxes``, ``labels``,
  196. ``scores`` and ``instances_id``.
  197. """
  198. metainfo = data_sample.metainfo
  199. bboxes = data_sample.pred_instances.bboxes
  200. labels = data_sample.pred_instances.labels
  201. scores = data_sample.pred_instances.scores
  202. frame_id = metainfo.get('frame_id', -1)
  203. # create pred_track_instances
  204. pred_track_instances = InstanceData()
  205. # return zero bboxes if there is no track targets
  206. if bboxes.shape[0] == 0:
  207. ids = torch.zeros_like(labels)
  208. pred_track_instances = data_sample.pred_instances.clone()
  209. pred_track_instances.instances_id = ids
  210. return pred_track_instances
  211. # get track feats
  212. rescaled_bboxes = bboxes.clone()
  213. if rescale:
  214. scale_factor = rescaled_bboxes.new_tensor(
  215. metainfo['scale_factor']).repeat((1, 2))
  216. rescaled_bboxes = rescaled_bboxes * scale_factor
  217. track_feats = model.track_head.predict(feats, [rescaled_bboxes])
  218. # sort according to the object_score
  219. _, inds = scores.sort(descending=True)
  220. bboxes = bboxes[inds]
  221. scores = scores[inds]
  222. labels = labels[inds]
  223. embeds = track_feats[inds, :]
  224. # duplicate removal for potential backdrops and cross classes
  225. valids = bboxes.new_ones((bboxes.size(0)))
  226. ious = bbox_overlaps(bboxes, bboxes)
  227. for i in range(1, bboxes.size(0)):
  228. thr = self.nms_backdrop_iou_thr if scores[
  229. i] < self.obj_score_thr else self.nms_class_iou_thr
  230. if (ious[i, :i] > thr).any():
  231. valids[i] = 0
  232. valids = valids == 1
  233. bboxes = bboxes[valids]
  234. scores = scores[valids]
  235. labels = labels[valids]
  236. embeds = embeds[valids, :]
  237. # init ids container
  238. ids = torch.full((bboxes.size(0), ), -1, dtype=torch.long)
  239. # match if buffer is not empty
  240. if bboxes.size(0) > 0 and not self.empty:
  241. (memo_bboxes, memo_labels, memo_embeds, memo_ids,
  242. memo_vs) = self.memo
  243. if self.match_metric == 'bisoftmax':
  244. feats = torch.mm(embeds, memo_embeds.t())
  245. d2t_scores = feats.softmax(dim=1)
  246. t2d_scores = feats.softmax(dim=0)
  247. match_scores = (d2t_scores + t2d_scores) / 2
  248. elif self.match_metric == 'softmax':
  249. feats = torch.mm(embeds, memo_embeds.t())
  250. match_scores = feats.softmax(dim=1)
  251. elif self.match_metric == 'cosine':
  252. match_scores = torch.mm(
  253. F.normalize(embeds, p=2, dim=1),
  254. F.normalize(memo_embeds, p=2, dim=1).t())
  255. else:
  256. raise NotImplementedError
  257. # track with the same category
  258. if self.with_cats:
  259. cat_same = labels.view(-1, 1) == memo_labels.view(1, -1)
  260. match_scores *= cat_same.float().to(match_scores.device)
  261. # track according to match_scores
  262. for i in range(bboxes.size(0)):
  263. conf, memo_ind = torch.max(match_scores[i, :], dim=0)
  264. id = memo_ids[memo_ind]
  265. if conf > self.match_score_thr:
  266. if id > -1:
  267. # keep bboxes with high object score
  268. # and remove background bboxes
  269. if scores[i] > self.obj_score_thr:
  270. ids[i] = id
  271. match_scores[:i, memo_ind] = 0
  272. match_scores[i + 1:, memo_ind] = 0
  273. else:
  274. if conf > self.nms_conf_thr:
  275. ids[i] = -2
  276. # initialize new tracks
  277. new_inds = (ids == -1) & (scores > self.init_score_thr).cpu()
  278. num_news = new_inds.sum()
  279. ids[new_inds] = torch.arange(
  280. self.num_tracks, self.num_tracks + num_news, dtype=torch.long)
  281. self.num_tracks += num_news
  282. self.update(ids, bboxes, embeds, labels, scores, frame_id)
  283. tracklet_inds = ids > -1
  284. # update pred_track_instances
  285. pred_track_instances.bboxes = bboxes[tracklet_inds]
  286. pred_track_instances.labels = labels[tracklet_inds]
  287. pred_track_instances.scores = scores[tracklet_inds]
  288. pred_track_instances.instances_id = ids[tracklet_inds]
  289. return pred_track_instances