mot_challenge_metric.py

# Copyright (c) OpenMMLab. All rights reserved.
import os
import os.path as osp
import shutil
import tempfile
from collections import defaultdict
from typing import List, Optional, Union

import numpy as np
import torch

try:
    import trackeval
except ImportError:
    trackeval = None

from mmengine.dist import (all_gather_object, barrier, broadcast,
                           broadcast_object_list, get_dist_info,
                           is_main_process)
from mmengine.logging import MMLogger

from mmdet.registry import METRICS, TASK_UTILS
from .base_video_metric import BaseVideoMetric


def get_tmpdir() -> str:
    """return the same tmpdir for all processes."""
    rank, world_size = get_dist_info()
    MAX_LEN = 512
    # 32 is whitespace
    dir_tensor = torch.full((MAX_LEN, ), 32, dtype=torch.uint8)
    if rank == 0:
        tmpdir = tempfile.mkdtemp()
        tmpdir = torch.tensor(bytearray(tmpdir.encode()), dtype=torch.uint8)
        dir_tensor[:len(tmpdir)] = tmpdir
    broadcast(dir_tensor, 0)
    tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()
    return tmpdir
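

# NOTE: `get_tmpdir` gives every rank the temporary directory created by
# rank 0, so all ranks write their per-sequence files into one shared
# location (`__init__` below points `self.tmp_dir.name` at it). A minimal,
# hypothetical usage sketch (the file name below is illustrative only):
#
#   shared_dir = get_tmpdir()
#   rank, _ = get_dist_info()
#   with open(osp.join(shared_dir, f'rank_{rank}.txt'), 'w') as f:
#       f.write('partial results\n')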


@METRICS.register_module()
class MOTChallengeMetric(BaseVideoMetric):
    """Evaluation metrics for MOT Challenge.

    Args:
        metric (str | list[str]): Metrics to be evaluated. Options are
            'HOTA', 'CLEAR', 'Identity'.
            Defaults to ['HOTA', 'CLEAR', 'Identity'].
        outfile_prefix (str, optional): Path to save the formatted results.
            Defaults to None.
        track_iou_thr (float): IoU threshold for tracking evaluation.
            Defaults to 0.5.
        benchmark (str): Benchmark to be evaluated. Defaults to 'MOT17'.
        format_only (bool): If True, only format the results to the official
            format without performing evaluation. Defaults to False.
        use_postprocess (bool): Whether the predicted bboxes have already
            been post-processed into MOT-style track rows that can be
            written out directly. Defaults to False.
        postprocess_tracklet_cfg (List[dict], optional): configs for tracklet
            postprocessing methods. `InterpolateTracklets` is supported.
            Defaults to [].
            - InterpolateTracklets:
                - min_num_frames (int, optional): The minimum length of a
                    track that will be interpolated. Defaults to 5.
                - max_num_frames (int, optional): The maximum disconnected
                    length in a track. Defaults to 20.
                - use_gsi (bool, optional): Whether to use the GSI (Gaussian-
                    smoothed interpolation) method. Defaults to False.
                - smooth_tau (int, optional): smoothing parameter in GSI.
                    Defaults to 10.
        collect_device (str): Device name used for collecting results from
            different ranks during distributed training. Must be 'cpu' or
            'gpu'. Defaults to 'cpu'.
        prefix (str, optional): The prefix that will be added in the metric
            names to disambiguate homonymous metrics of different evaluators.
            If prefix is not provided in the argument, self.default_prefix
            will be used instead. Defaults to None.
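
    Examples:
        A minimal, illustrative evaluator config; treat the field values
        below as placeholders rather than recommended settings:

        >>> val_evaluator = dict(
        ...     type='MOTChallengeMetric',
        ...     metric=['HOTA', 'CLEAR', 'Identity'],
        ...     benchmark='MOT17',
        ...     postprocess_tracklet_cfg=[
        ...         dict(
        ...             type='InterpolateTracklets',
        ...             min_num_frames=5,
        ...             max_num_frames=20)
        ...     ])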
    """
    TRACKER = 'default-tracker'
    allowed_metrics = ['HOTA', 'CLEAR', 'Identity']
    allowed_benchmarks = ['MOT15', 'MOT16', 'MOT17', 'MOT20', 'DanceTrack']
    default_prefix: Optional[str] = 'motchallenge-metric'

    def __init__(self,
                 metric: Union[str, List[str]] = ['HOTA', 'CLEAR', 'Identity'],
                 outfile_prefix: Optional[str] = None,
                 track_iou_thr: float = 0.5,
                 benchmark: str = 'MOT17',
                 format_only: bool = False,
                 use_postprocess: bool = False,
                 postprocess_tracklet_cfg: Optional[List[dict]] = [],
                 collect_device: str = 'cpu',
                 prefix: Optional[str] = None) -> None:
        super().__init__(collect_device=collect_device, prefix=prefix)
        if trackeval is None:
            raise RuntimeError(
                'trackeval is not installed, please install it by: '
                'pip install '
                'git+https://github.com/JonathonLuiten/TrackEval.git. '
                'trackeval needs a low version of numpy, please install it '
                'by: pip install -U numpy==1.23.5')
        if isinstance(metric, list):
            metrics = metric
        elif isinstance(metric, str):
            metrics = [metric]
        else:
            raise TypeError('metric must be a list or a str.')
        for metric in metrics:
            if metric not in self.allowed_metrics:
                raise KeyError(f'metric {metric} is not supported.')
        self.metrics = metrics
        self.format_only = format_only
        if self.format_only:
            assert outfile_prefix is not None, (
                'outfile_prefix must not be None when format_only is True, '
                'otherwise the result files will be saved to a temp '
                'directory which will be cleaned up at the end.')
        self.use_postprocess = use_postprocess
        self.postprocess_tracklet_cfg = postprocess_tracklet_cfg.copy()
        self.postprocess_tracklet_methods = [
            TASK_UTILS.build(cfg) for cfg in self.postprocess_tracklet_cfg
        ]
        assert benchmark in self.allowed_benchmarks
        self.benchmark = benchmark
        self.track_iou_thr = track_iou_thr
        self.tmp_dir = tempfile.TemporaryDirectory()
        self.tmp_dir.name = get_tmpdir()
        self.seq_info = defaultdict(
            lambda: dict(seq_length=-1, gt_tracks=[], pred_tracks=[]))
        self.gt_dir = self._get_gt_dir()
        self.pred_dir = self._get_pred_dir(outfile_prefix)
        self.seqmap = osp.join(self.pred_dir, 'videoseq.txt')
        with open(self.seqmap, 'w') as f:
            f.write('name\n')

    def __del__(self):
        # To avoid tmpdir being cleaned up too early, because in multiple
        # consecutive ValLoops, the value of `self.tmp_dir.name` is unchanged,
        # and calling `tmp_dir.cleanup()` in compute_metrics will cause errors.
        self.tmp_dir.cleanup()

    def _get_pred_dir(self, outfile_prefix):
        """Get directory to save the prediction results."""
        logger: MMLogger = MMLogger.get_current_instance()

        if outfile_prefix is None:
            outfile_prefix = self.tmp_dir.name
        else:
            if osp.exists(outfile_prefix) and is_main_process():
                logger.info('remove previous results.')
                shutil.rmtree(outfile_prefix)
        pred_dir = osp.join(outfile_prefix, self.TRACKER)
        os.makedirs(pred_dir, exist_ok=True)
        return pred_dir

    def _get_gt_dir(self):
        """Get directory to save the gt files."""
        output_dir = osp.join(self.tmp_dir.name, 'gt')
        os.makedirs(output_dir, exist_ok=True)
        return output_dir

    def transform_gt_and_pred(self, img_data_sample, video, frame_id):

        video = img_data_sample['img_path'].split(os.sep)[-3]
        # load gts
        if 'instances' in img_data_sample:
            gt_instances = img_data_sample['instances']
            # each gt row: (frame, instance_id, x1, y1, w, h, mot_conf,
            # category_id, visibility)
            gt_tracks = [
                np.array([
                    frame_id + 1, gt_instances[i]['instance_id'],
                    gt_instances[i]['bbox'][0], gt_instances[i]['bbox'][1],
                    gt_instances[i]['bbox'][2] - gt_instances[i]['bbox'][0],
                    gt_instances[i]['bbox'][3] - gt_instances[i]['bbox'][1],
                    gt_instances[i]['mot_conf'],
                    gt_instances[i]['category_id'],
                    gt_instances[i]['visibility']
                ]) for i in range(len(gt_instances))
            ]
            self.seq_info[video]['gt_tracks'].extend(gt_tracks)

        # load predictions
        assert 'pred_track_instances' in img_data_sample
        if self.use_postprocess:
            pred_instances = img_data_sample['pred_track_instances']
            pred_tracks = [
                pred_instances['bboxes'][i]
                for i in range(len(pred_instances['bboxes']))
            ]
        else:
            pred_instances = img_data_sample['pred_track_instances']
            # each prediction row: (frame, instance_id, x1, y1, w, h, score)
            pred_tracks = [
                np.array([
                    frame_id + 1, pred_instances['instances_id'][i].cpu(),
                    pred_instances['bboxes'][i][0].cpu(),
                    pred_instances['bboxes'][i][1].cpu(),
                    (pred_instances['bboxes'][i][2] -
                     pred_instances['bboxes'][i][0]).cpu(),
                    (pred_instances['bboxes'][i][3] -
                     pred_instances['bboxes'][i][1]).cpu(),
                    pred_instances['scores'][i].cpu()
                ]) for i in range(len(pred_instances['instances_id']))
            ]
        self.seq_info[video]['pred_tracks'].extend(pred_tracks)

    def process_image(self, data_samples, video_len):
        img_data_sample = data_samples[0].to_dict()
        video = img_data_sample['img_path'].split(os.sep)[-3]
        frame_id = img_data_sample['frame_id']
        if self.seq_info[video]['seq_length'] == -1:
            self.seq_info[video]['seq_length'] = video_len
        self.transform_gt_and_pred(img_data_sample, video, frame_id)

        if frame_id == video_len - 1:
            # postprocessing
            if self.postprocess_tracklet_cfg:
                info = self.seq_info[video]
                pred_tracks = np.array(info['pred_tracks'])
                for postprocess_tracklet_methods in \
                        self.postprocess_tracklet_methods:
                    pred_tracks = postprocess_tracklet_methods\
                        .forward(pred_tracks)
                info['pred_tracks'] = pred_tracks
            self._save_one_video_gts_preds(video)

    def process_video(self, data_samples):
        video_len = len(data_samples)
        for frame_id in range(video_len):
            img_data_sample = data_samples[frame_id].to_dict()
            # load basic info
            video = img_data_sample['img_path'].split(os.sep)[-3]
            if self.seq_info[video]['seq_length'] == -1:
                self.seq_info[video]['seq_length'] = video_len
            self.transform_gt_and_pred(img_data_sample, video, frame_id)

        if self.postprocess_tracklet_cfg:
            info = self.seq_info[video]
            pred_tracks = np.array(info['pred_tracks'])
            for postprocess_tracklet_methods in \
                    self.postprocess_tracklet_methods:
                pred_tracks = postprocess_tracklet_methods \
                    .forward(pred_tracks)
            info['pred_tracks'] = pred_tracks
        self._save_one_video_gts_preds(video)

    def _save_one_video_gts_preds(self, seq: str) -> None:
        """Save the gt and prediction results."""
        info = self.seq_info[seq]
        # save predictions
        pred_file = osp.join(self.pred_dir, seq + '.txt')

        pred_tracks = np.array(info['pred_tracks'])

        with open(pred_file, 'wt') as f:
            for tracks in pred_tracks:
                # MOTChallenge result format:
                # frame,id,x,y,w,h,conf,-1,-1,-1
                line = '%d,%d,%.3f,%.3f,%.3f,%.3f,%.3f,-1,-1,-1\n' % (
                    tracks[0], tracks[1], tracks[2], tracks[3], tracks[4],
                    tracks[5], tracks[6])
                f.writelines(line)
        info['pred_tracks'] = []

        # save gts
        if info['gt_tracks']:
            gt_file = osp.join(self.gt_dir, seq + '.txt')
            with open(gt_file, 'wt') as f:
                for tracks in info['gt_tracks']:
                    # MOTChallenge gt format:
                    # frame,id,x,y,w,h,conf,class,visibility
                    line = '%d,%d,%d,%d,%d,%d,%d,%d,%.5f\n' % (
                        tracks[0], tracks[1], tracks[2], tracks[3], tracks[4],
                        tracks[5], tracks[6], tracks[7], tracks[8])
                    f.writelines(line)
            info['gt_tracks'].clear()

        # save seq info
        with open(self.seqmap, 'a') as f:
            f.write(seq + '\n')

    def compute_metrics(self, results: list = None) -> dict:
        """Compute the metrics from processed results.

        Args:
            results (list): The processed results of each batch.
                Defaults to None.

        Returns:
            dict: The computed metrics. The keys are the names of the metrics,
            and the values are corresponding results.
        """
        logger: MMLogger = MMLogger.get_current_instance()
        # NOTICE: don't access `self.results` from the method.
        eval_results = dict()

        if self.format_only:
            return eval_results

        eval_config = trackeval.Evaluator.get_default_eval_config()

        # need to split out the tracker name
        # caused by the implementation of TrackEval
        pred_dir_tmp = self.pred_dir.rsplit(osp.sep, 1)[0]
        dataset_config = self.get_dataset_cfg(self.gt_dir, pred_dir_tmp)

        evaluator = trackeval.Evaluator(eval_config)
        dataset = [trackeval.datasets.MotChallenge2DBox(dataset_config)]
        metrics = [
            getattr(trackeval.metrics,
                    metric)(dict(METRICS=[metric], THRESHOLD=0.5))
            for metric in self.metrics
        ]
        output_res, _ = evaluator.evaluate(dataset, metrics)
        output_res = output_res['MotChallenge2DBox'][
            self.TRACKER]['COMBINED_SEQ']['pedestrian']

        if 'HOTA' in self.metrics:
            logger.info('Evaluating HOTA Metrics...')
            eval_results['HOTA'] = np.average(output_res['HOTA']['HOTA'])
            eval_results['AssA'] = np.average(output_res['HOTA']['AssA'])
            eval_results['DetA'] = np.average(output_res['HOTA']['DetA'])

        if 'CLEAR' in self.metrics:
            logger.info('Evaluating CLEAR Metrics...')
            eval_results['MOTA'] = np.average(output_res['CLEAR']['MOTA'])
            eval_results['MOTP'] = np.average(output_res['CLEAR']['MOTP'])
            eval_results['IDSW'] = np.average(output_res['CLEAR']['IDSW'])
            eval_results['TP'] = np.average(output_res['CLEAR']['CLR_TP'])
            eval_results['FP'] = np.average(output_res['CLEAR']['CLR_FP'])
            eval_results['FN'] = np.average(output_res['CLEAR']['CLR_FN'])
            eval_results['Frag'] = np.average(output_res['CLEAR']['Frag'])
            eval_results['MT'] = np.average(output_res['CLEAR']['MT'])
            eval_results['ML'] = np.average(output_res['CLEAR']['ML'])

        if 'Identity' in self.metrics:
            logger.info('Evaluating Identity Metrics...')
            eval_results['IDF1'] = np.average(output_res['Identity']['IDF1'])
            eval_results['IDTP'] = np.average(output_res['Identity']['IDTP'])
            eval_results['IDFN'] = np.average(output_res['Identity']['IDFN'])
            eval_results['IDFP'] = np.average(output_res['Identity']['IDFP'])
            eval_results['IDP'] = np.average(output_res['Identity']['IDP'])
            eval_results['IDR'] = np.average(output_res['Identity']['IDR'])

        return eval_results

    def evaluate(self, size: int = 1) -> dict:
        """Evaluate the model performance of the whole dataset after
        processing all batches.

        Args:
            size (int): Length of the entire validation dataset.
                Defaults to 1.

        Returns:
            dict: Evaluation metrics dict on the val dataset. The keys are the
            names of the metrics, and the values are corresponding results.
        """
        # wait for all processes to complete prediction.
        barrier()

        # gather seq_info and convert the list of dict to a dict.
        # convert self.seq_info to dict first to make it picklable.
        gathered_seq_info = all_gather_object(dict(self.seq_info))
        all_seq_info = dict()
        for _seq_info in gathered_seq_info:
            all_seq_info.update(_seq_info)
        self.seq_info = all_seq_info

        if is_main_process():
            _metrics = self.compute_metrics()  # type: ignore
            # Add prefix to metric names
            if self.prefix:
                _metrics = {
                    '/'.join((self.prefix, k)): v
                    for k, v in _metrics.items()
                }
            metrics = [_metrics]
        else:
            metrics = [None]  # type: ignore
        broadcast_object_list(metrics)

        # reset the results list
        self.results.clear()
        return metrics[0]

    def get_dataset_cfg(self, gt_folder: str, tracker_folder: str):
        """Get default configs for trackeval.datasets.MotChallenge2DBox.

        Args:
            gt_folder (str): the name of the GT folder
            tracker_folder (str): the name of the tracker folder

        Returns:
            Dataset Configs for MotChallenge2DBox.
        """
        dataset_config = dict(
            # Location of GT data
            GT_FOLDER=gt_folder,
            # Trackers location
            TRACKERS_FOLDER=tracker_folder,
            # Where to save eval results
            # (if None, same as TRACKERS_FOLDER)
            OUTPUT_FOLDER=None,
            # Use self.TRACKER as the default tracker
            TRACKERS_TO_EVAL=[self.TRACKER],
            # Option values: ['pedestrian']
            CLASSES_TO_EVAL=['pedestrian'],
            # Option Values: 'MOT15', 'MOT16', 'MOT17', 'MOT20', 'DanceTrack'
            BENCHMARK=self.benchmark,
            # Option Values: 'train', 'test'
            SPLIT_TO_EVAL='val' if self.benchmark == 'DanceTrack' else 'train',
            # Whether tracker input files are zipped
            INPUT_AS_ZIP=False,
            # Whether to print current config
            PRINT_CONFIG=True,
            # Whether to perform preprocessing
            # (never done for MOT15)
            DO_PREPROC=False if self.benchmark == 'MOT15' else True,
            # Tracker files are in
            # TRACKER_FOLDER/tracker_name/TRACKER_SUB_FOLDER
            TRACKER_SUB_FOLDER='',
            # Output files are saved in
            # OUTPUT_FOLDER/tracker_name/OUTPUT_SUB_FOLDER
            OUTPUT_SUB_FOLDER='',
            # Names of trackers to display
            # (if None: TRACKERS_TO_EVAL)
            TRACKER_DISPLAY_NAMES=None,
            # Where seqmaps are found
            # (if None: GT_FOLDER/seqmaps)
            SEQMAP_FOLDER=None,
            # Directly specify seqmap file
            # (if none use seqmap_folder/benchmark-split_to_eval)
            SEQMAP_FILE=self.seqmap,
            # If not None, specify sequences to eval
            # and their number of timesteps
            SEQ_INFO={
                seq: info['seq_length']
                for seq, info in self.seq_info.items()
            },
            # '{gt_folder}/{seq}.txt'
            GT_LOC_FORMAT='{gt_folder}/{seq}.txt',
            # If False, data is in GT_FOLDER/BENCHMARK-SPLIT_TO_EVAL/ and in
            # TRACKERS_FOLDER/BENCHMARK-SPLIT_TO_EVAL/tracker/
            # If True, the middle 'benchmark-split' folder is skipped for both.
            SKIP_SPLIT_FOL=True,
        )

        return dataset_config
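

# For reference, with the default `outfile_prefix=None` the on-disk layout
# that the config above points TrackEval at looks roughly like this
# (illustrative; `<tmp>` is the shared temporary directory from `get_tmpdir`):
#
#   <tmp>/gt/<seq>.txt                  # ground truth, see GT_LOC_FORMAT
#   <tmp>/default-tracker/<seq>.txt     # predictions (TRACKER_SUB_FOLDER='')
#   <tmp>/default-tracker/videoseq.txt  # seqmap listing the sequence names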