mot_error_visualize.py

# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import os.path as osp
import re

import mmcv
import motmetrics as mm
import numpy as np
import pandas as pd
from mmengine import Config
from mmengine.logging import print_log
from mmengine.registry import init_default_scope
from torch.utils.data import Dataset

from mmdet.registry import DATASETS
from mmdet.utils import imshow_mot_errors
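
# Example invocation (paths are placeholders; the arguments are the ones
# defined in parse_args() below):
#   python mot_error_visualize.py ${CONFIG_FILE} \
#       --result-dir ${RESULT_DIR} --out-dir ${OUT_DIR} --fps 3 --backend cv2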


def parse_args():
    parser = argparse.ArgumentParser(
        description='visualize errors for multiple object tracking')
    parser.add_argument('config', help='path of the config file')
    parser.add_argument(
        '--result-dir', help='directory of the inference result')
    parser.add_argument(
        '--out-dir',
        help='directory where painted images or videos will be saved')
    parser.add_argument(
        '--show',
        action='store_true',
        help='whether to show the results on the fly')
    parser.add_argument(
        '--fps', type=int, default=3, help='FPS of the output video')
    parser.add_argument(
        '--backend',
        type=str,
        choices=['cv2', 'plt'],
        default='cv2',
        help='backend of visualization')
    args = parser.parse_args()
    return args


def compare_res_gts(results_dir: str, dataset: Dataset, video_name: str):
    """Evaluate the results of the video.

    Args:
        results_dir (str): the directory of the MOT results.
        dataset (Dataset): MOT dataset of the video to be evaluated.
        video_name (str): Name of the video to be evaluated.

    Returns:
        tuple: (acc, res, gt), acc contains the results of MOT metrics,
        res is the results of inference and gt is the ground truth.
    """
    if 'half-train' in dataset.ann_file:
        gt_file = osp.join(dataset.data_prefix['img_path'],
                           f'{video_name}/gt/gt_half-train.txt')
        gt = mm.io.loadtxt(gt_file)
        gt.index = gt.index.set_levels(
            pd.factorize(gt.index.levels[0])[0] + 1, level=0)
    elif 'half-val' in dataset.ann_file:
        gt_file = osp.join(dataset.data_prefix['img_path'],
                           f'{video_name}/gt/gt_half-val.txt')
        gt = mm.io.loadtxt(gt_file)
        gt.index = gt.index.set_levels(
            pd.factorize(gt.index.levels[0])[0] + 1, level=0)
    else:
        gt_file = osp.join(dataset.data_prefix['img_path'],
                           f'{video_name}/gt/gt.txt')
        gt = mm.io.loadtxt(gt_file)
        gt.index = gt.index.set_levels(
            pd.factorize(gt.index.levels[0])[0] + 1, level=0)
    res_file = osp.join(results_dir, f'{video_name}.txt')
    res = mm.io.loadtxt(res_file)
    ini_file = osp.join(dataset.data_prefix['img_path'],
                        f'{video_name}/seqinfo.ini')
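    # seqinfo.ini is the MOTChallenge sequence description file; when it is
    # available, CLEAR_MOT_M runs the MOTChallenge-style comparison that reads
    # it, otherwise fall back to a plain ground-truth comparison.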
    if osp.exists(ini_file):
        acc, _ = mm.utils.CLEAR_MOT_M(gt, res, ini_file)
    else:
        acc = mm.utils.compare_to_groundtruth(gt, res)

    return acc, res, gt


def main():
    args = parse_args()

    assert args.show or args.out_dir, \
        ('Please specify at least one operation (show the results '
         '/ save the results) with the argument "--show" or "--out-dir"')

    if args.out_dir is not None:
        os.makedirs(args.out_dir, exist_ok=True)

    print_log(
        'This script visualizes the errors for multiple object tracking. '
        'By default, the red bounding box denotes false positive, '
        'the yellow bounding box denotes false negative '
        'and the blue bounding box denotes ID switch.')
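
    # Build the validation dataset defined in the config; it is used below to
    # locate the ground-truth files and the image path of every frame.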
    cfg = Config.fromfile(args.config)
    init_default_scope(cfg.get('default_scope', 'mmdet'))
    dataset = DATASETS.build(cfg.val_dataloader.dataset)

    # create index from frame_id to filename
    filenames_dict = dict()
    for i in range(len(dataset)):
        video_info = dataset.get_data_info(i)
        # the `data_info['file_name']` usually has the same format
        # with "MOT17-09-DPM/img1/000003.jpg"
        # split with both '\' and '/' to be compatible with different OS.
        for data_info in video_info['images']:
            split_path = re.split(r'[\\/]', data_info['file_name'])
            video_name = split_path[-3]
            frame_id = int(data_info['frame_id'] + 1)
            if video_name not in filenames_dict:
                filenames_dict[video_name] = dict()
            # the data_info['img_path'] usually has the same format
            # with `img_path_prefix + "MOT17-09-DPM/img1/000003.jpg"`
            filenames_dict[video_name][frame_id] = data_info['img_path']
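    # At this point `filenames_dict` maps video name -> 1-based frame id ->
    # full image path, e.g. filenames_dict['MOT17-09-DPM'][3] would point to
    # '<img_path_prefix>/MOT17-09-DPM/img1/000003.jpg'.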
    video_names = tuple(filenames_dict.keys())

    for video_name in video_names:
        print_log(f'Start processing video {video_name}')

        acc, res, gt = compare_res_gts(args.result_dir, dataset, video_name)

        frames_id_list = sorted(
            list(set(acc.mot_events.index.get_level_values(0))))
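        # `acc.mot_events` is a DataFrame indexed by (FrameId, EventId); its
        # `Type` column records the outcome of every match attempt, e.g.
        # 'MATCH', 'FP', 'MISS' or 'SWITCH'.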
        for frame_id in frames_id_list:
            # events in the current frame
            events = acc.mot_events.xs(frame_id)
            cur_res = res.loc[frame_id] if frame_id in res.index else None
            cur_gt = gt.loc[frame_id] if frame_id in gt.index else None
            # path of image
            img = filenames_dict[video_name][frame_id]
            fps = events[events.Type == 'FP']
            fns = events[events.Type == 'MISS']
            idsws = events[events.Type == 'SWITCH']

            bboxes, ids, error_types = [], [], []
            for fp_index in fps.index:
                hid = events.loc[fp_index].HId
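                # Results and GT store boxes as (X, Y, Width, Height); convert
                # them to [x1, y1, x2, y2, confidence] for visualization.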
                bboxes.append([
                    cur_res.loc[hid].X, cur_res.loc[hid].Y,
                    cur_res.loc[hid].X + cur_res.loc[hid].Width,
                    cur_res.loc[hid].Y + cur_res.loc[hid].Height,
                    cur_res.loc[hid].Confidence
                ])
                ids.append(hid)
                # error_type = 0 denotes false positive error
                error_types.append(0)
            for fn_index in fns.index:
                oid = events.loc[fn_index].OId
                bboxes.append([
                    cur_gt.loc[oid].X, cur_gt.loc[oid].Y,
                    cur_gt.loc[oid].X + cur_gt.loc[oid].Width,
                    cur_gt.loc[oid].Y + cur_gt.loc[oid].Height,
                    cur_gt.loc[oid].Confidence
                ])
                ids.append(-1)
                # error_type = 1 denotes false negative error
                error_types.append(1)
            for idsw_index in idsws.index:
                hid = events.loc[idsw_index].HId
                bboxes.append([
                    cur_res.loc[hid].X, cur_res.loc[hid].Y,
                    cur_res.loc[hid].X + cur_res.loc[hid].Width,
                    cur_res.loc[hid].Y + cur_res.loc[hid].Height,
                    cur_res.loc[hid].Confidence
                ])
                ids.append(hid)
                # error_type = 2 denotes id switch
                error_types.append(2)
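            # Keep an explicit (N, 5) box array even when the current frame
            # has no errors, so the drawing call below always gets a 2-D array.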
            if len(bboxes) == 0:
                bboxes = np.zeros((0, 5), dtype=np.float32)
            else:
                bboxes = np.asarray(bboxes, dtype=np.float32)
            ids = np.asarray(ids, dtype=np.int32)
            error_types = np.asarray(error_types, dtype=np.int32)
            imshow_mot_errors(
                img,
                bboxes,
                ids,
                error_types,
                show=args.show,
                out_file=osp.join(args.out_dir,
                                  f'{video_name}/{frame_id:06d}.jpg')
                if args.out_dir else None,
                backend=args.backend)

        # Saving images and stitching the video is skipped when only `--show`
        # is requested.
        if args.out_dir:
            print_log(f'Done! Visualization images are saved in '
                      f'\'{args.out_dir}/{video_name}\'')

            mmcv.frames2video(
                f'{args.out_dir}/{video_name}',
                f'{args.out_dir}/{video_name}.mp4',
                fps=args.fps,
                fourcc='mp4v',
                start=frames_id_list[0],
                end=frames_id_list[-1],
                show_progress=False)
            print_log(
                f'Done! Visualization video is saved as '
                f'\'{args.out_dir}/{video_name}.mp4\' with a FPS of '
                f'{args.fps}')


if __name__ == '__main__':
    main()