# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import os.path as osp
import re

import mmcv
import motmetrics as mm
import numpy as np
import pandas as pd
from mmengine import Config
from mmengine.logging import print_log
from mmengine.registry import init_default_scope
from torch.utils.data import Dataset

from mmdet.registry import DATASETS
from mmdet.utils import imshow_mot_errors


def parse_args():
    parser = argparse.ArgumentParser(
        description='visualize errors for multiple object tracking')
    parser.add_argument('config', help='path of the config file')
    parser.add_argument(
        '--result-dir', help='directory of the inference result')
    parser.add_argument(
        '--out-dir',
        help='directory where painted images or videos will be saved')
    parser.add_argument(
        '--show',
        action='store_true',
        help='whether to show the results on the fly')
    parser.add_argument(
        '--fps', type=int, default=3, help='FPS of the output video')
    parser.add_argument(
        '--backend',
        type=str,
        choices=['cv2', 'plt'],
        default='cv2',
        help='backend of visualization')
    args = parser.parse_args()
    return args
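
# Example invocation (a sketch; the config path and directories below are
# illustrative placeholders, not files shipped with this script):
#
#   python mot_error_visualize.py \
#       configs/mot/some_tracker_config.py \
#       --result-dir ./mot_results --out-dir ./vis --fps 3 --backend cv2
#
# `--result-dir` is expected to contain one `<video_name>.txt` file per
# sequence in MOTChallenge format, as consumed by `compare_res_gts` below.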


def compare_res_gts(results_dir: str, dataset: Dataset, video_name: str):
    """Evaluate the results of the video.

    Args:
        results_dir (str): The directory of the MOT results.
        dataset (Dataset): MOT dataset of the video to be evaluated.
        video_name (str): Name of the video to be evaluated.

    Returns:
        tuple: (acc, res, gt), where ``acc`` contains the results of the
        MOT metrics, ``res`` is the inference result and ``gt`` is the
        ground truth.
    """
    if 'half-train' in dataset.ann_file:
        gt_file = osp.join(dataset.data_prefix['img_path'],
                           f'{video_name}/gt/gt_half-train.txt')
    elif 'half-val' in dataset.ann_file:
        gt_file = osp.join(dataset.data_prefix['img_path'],
                           f'{video_name}/gt/gt_half-val.txt')
    else:
        gt_file = osp.join(dataset.data_prefix['img_path'],
                           f'{video_name}/gt/gt.txt')
    gt = mm.io.loadtxt(gt_file)
    # re-number the frame ids of the ground truth to consecutive integers
    # starting from 1 (the half splits do not start at frame 1)
    gt.index = gt.index.set_levels(
        pd.factorize(gt.index.levels[0])[0] + 1, level=0)
    res_file = osp.join(results_dir, f'{video_name}.txt')
    res = mm.io.loadtxt(res_file)
    ini_file = osp.join(dataset.data_prefix['img_path'],
                        f'{video_name}/seqinfo.ini')
    if osp.exists(ini_file):
        acc, _ = mm.utils.CLEAR_MOT_M(gt, res, ini_file)
    else:
        acc = mm.utils.compare_to_groundtruth(gt, res)
    return acc, res, gt
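
# Background on the motmetrics objects used above (a summary based on the
# public motmetrics API, not on anything defined in this file):
# `mm.io.loadtxt` reads the MOTChallenge CSV layout
# `<frame>,<id>,<x>,<y>,<w>,<h>,<conf>,...` into a DataFrame indexed by
# (FrameId, Id) with columns including X, Y, Width, Height and Confidence.
# The returned accumulator exposes `acc.mot_events`, a DataFrame indexed by
# (FrameId, EventId) whose `Type` column takes values such as 'MATCH',
# 'FP', 'MISS' and 'SWITCH'; `main` below filters on these types.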


def main():
    args = parse_args()

    assert args.show or args.out_dir, \
        ('Please specify at least one operation (show the results '
         '/ save the results) with the argument "--show" or "--out-dir"')

    if args.out_dir is not None:
        os.makedirs(args.out_dir, exist_ok=True)

    print_log('This script visualizes the errors for multiple object '
              'tracking. By default, the red bounding box denotes a false '
              'positive, the yellow bounding box denotes a false negative '
              'and the blue bounding box denotes an ID switch.')

    cfg = Config.fromfile(args.config)
    # switch the registry scope so that modules in the config resolve
    # against the mmdet registries
    init_default_scope(cfg.get('default_scope', 'mmdet'))
    dataset = DATASETS.build(cfg.val_dataloader.dataset)

    # create an index from frame_id to filename
    filenames_dict = dict()
    for i in range(len(dataset)):
        video_info = dataset.get_data_info(i)
        # `data_info['file_name']` usually has a format like
        # "MOT17-09-DPM/img1/000003.jpg"; split on both '\' and '/'
        # to be compatible with different OSes.
        for data_info in video_info['images']:
            split_path = re.split(r'[\\/]', data_info['file_name'])
            video_name = split_path[-3]
            frame_id = int(data_info['frame_id'] + 1)
            if video_name not in filenames_dict:
                filenames_dict[video_name] = dict()
            # `data_info['img_path']` usually has a format like
            # `img_path_prefix + "MOT17-09-DPM/img1/000003.jpg"`
            filenames_dict[video_name][frame_id] = data_info['img_path']
    video_names = tuple(filenames_dict.keys())
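    # the resulting mapping looks like (values illustrative):
    # {'MOT17-09-DPM': {1: '<prefix>/MOT17-09-DPM/img1/000001.jpg', ...}, ...}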

    for video_name in video_names:
        print_log(f'Start processing video {video_name}')

        acc, res, gt = compare_res_gts(args.result_dir, dataset, video_name)

        frames_id_list = sorted(
            set(acc.mot_events.index.get_level_values(0)))
        for frame_id in frames_id_list:
            # events in the current frame
            events = acc.mot_events.xs(frame_id)
            cur_res = res.loc[frame_id] if frame_id in res.index else None
            cur_gt = gt.loc[frame_id] if frame_id in gt.index else None
            # path of the image
            img = filenames_dict[video_name][frame_id]
            fps = events[events.Type == 'FP']
            fns = events[events.Type == 'MISS']
            idsws = events[events.Type == 'SWITCH']

            bboxes, ids, error_types = [], [], []
            for fp_index in fps.index:
                hid = events.loc[fp_index].HId
                bboxes.append([
                    cur_res.loc[hid].X, cur_res.loc[hid].Y,
                    cur_res.loc[hid].X + cur_res.loc[hid].Width,
                    cur_res.loc[hid].Y + cur_res.loc[hid].Height,
                    cur_res.loc[hid].Confidence
                ])
                ids.append(hid)
                # error_type = 0 denotes a false positive
                error_types.append(0)
            for fn_index in fns.index:
                oid = events.loc[fn_index].OId
                bboxes.append([
                    cur_gt.loc[oid].X, cur_gt.loc[oid].Y,
                    cur_gt.loc[oid].X + cur_gt.loc[oid].Width,
                    cur_gt.loc[oid].Y + cur_gt.loc[oid].Height,
                    cur_gt.loc[oid].Confidence
                ])
                ids.append(-1)
                # error_type = 1 denotes a false negative
                error_types.append(1)
            for idsw_index in idsws.index:
                hid = events.loc[idsw_index].HId
                bboxes.append([
                    cur_res.loc[hid].X, cur_res.loc[hid].Y,
                    cur_res.loc[hid].X + cur_res.loc[hid].Width,
                    cur_res.loc[hid].Y + cur_res.loc[hid].Height,
                    cur_res.loc[hid].Confidence
                ])
                ids.append(hid)
                # error_type = 2 denotes an ID switch
                error_types.append(2)
            if len(bboxes) == 0:
                bboxes = np.zeros((0, 5), dtype=np.float32)
            else:
                bboxes = np.asarray(bboxes, dtype=np.float32)
            ids = np.asarray(ids, dtype=np.int32)
            error_types = np.asarray(error_types, dtype=np.int32)
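
            # At this point `bboxes` is an (N, 5) float32 array of
            # (x1, y1, x2, y2, score) rows, aligned index-by-index with
            # `ids` and `error_types`. This layout is inferred from how the
            # rows are assembled above, not from the imshow_mot_errors docs.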
            imshow_mot_errors(
                img,
                bboxes,
                ids,
                error_types,
                show=args.show,
                out_file=osp.join(args.out_dir,
                                  f'{video_name}/{frame_id:06d}.jpg')
                if args.out_dir else None,
                backend=args.backend)

        # only export images and a video when an output directory was given
        if args.out_dir:
            print_log(f'Done! Visualization images are saved in '
                      f'\'{args.out_dir}/{video_name}\'')

            mmcv.frames2video(
                f'{args.out_dir}/{video_name}',
                f'{args.out_dir}/{video_name}.mp4',
                fps=args.fps,
                fourcc='mp4v',
                start=frames_id_list[0],
                end=frames_id_list[-1],
                show_progress=False)
            print_log(f'Done! Visualization video is saved as '
                      f'\'{args.out_dir}/{video_name}.mp4\' '
                      f'at {args.fps} FPS')


if __name__ == '__main__':
    main()