mot_demo.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import os
  3. import os.path as osp
  4. import tempfile
  5. from argparse import ArgumentParser
  6. import mmcv
  7. import mmengine
  8. from mmengine.registry import init_default_scope
  9. from mmdet.apis import inference_mot, init_track_model
  10. from mmdet.registry import VISUALIZERS
  11. IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png')
  12. def parse_args():
  13. parser = ArgumentParser()
  14. parser.add_argument(
  15. 'inputs', type=str, help='Input image file or folder path.')
  16. parser.add_argument('config', help='config file')
  17. parser.add_argument('--checkpoint', help='checkpoint file')
  18. parser.add_argument('--detector', help='det checkpoint file')
  19. parser.add_argument('--reid', help='reid checkpoint file')
  20. parser.add_argument(
  21. '--device', default='cuda:0', help='device used for inference')
  22. parser.add_argument(
  23. '--score-thr',
  24. type=float,
  25. default=0.0,
  26. help='The threshold of score to filter bboxes.')
  27. parser.add_argument(
  28. '--out', help='output video file (mp4 format) or folder')
  29. parser.add_argument(
  30. '--show',
  31. action='store_true',
  32. help='whether show the results on the fly')
  33. parser.add_argument('--fps', help='FPS of the output video')
  34. args = parser.parse_args()
  35. return args
  36. def main(args):
  37. assert args.out or args.show
  38. # load images
  39. if osp.isdir(args.inputs):
  40. imgs = sorted(
  41. filter(lambda x: x.endswith(IMG_EXTENSIONS),
  42. os.listdir(args.inputs)),
  43. key=lambda x: int(x.split('.')[0]))
  44. in_video = False
  45. else:
  46. imgs = mmcv.VideoReader(args.inputs)
  47. in_video = True
  48. # define output
  49. out_video = False
  50. if args.out is not None:
  51. if args.out.endswith('.mp4'):
  52. out_video = True
  53. out_dir = tempfile.TemporaryDirectory()
  54. out_path = out_dir.name
  55. _out = args.out.rsplit(os.sep, 1)
  56. if len(_out) > 1:
  57. os.makedirs(_out[0], exist_ok=True)
  58. else:
  59. out_path = args.out
  60. os.makedirs(out_path, exist_ok=True)
  61. fps = args.fps
  62. if args.show or out_video:
  63. if fps is None and in_video:
  64. fps = imgs.fps
  65. if not fps:
  66. raise ValueError('Please set the FPS for the output video.')
  67. fps = int(fps)
  68. init_default_scope('mmdet')
  69. # build the model from a config file and a checkpoint file
  70. model = init_track_model(
  71. args.config,
  72. args.checkpoint,
  73. args.detector,
  74. args.reid,
  75. device=args.device)
  76. # build the visualizer
  77. visualizer = VISUALIZERS.build(model.cfg.visualizer)
  78. visualizer.dataset_meta = model.dataset_meta
  79. prog_bar = mmengine.ProgressBar(len(imgs))
  80. # test and show/save the images
  81. for i, img in enumerate(imgs):
  82. if isinstance(img, str):
  83. img_path = osp.join(args.inputs, img)
  84. img = mmcv.imread(img_path)
  85. # result [TrackDataSample]
  86. result = inference_mot(model, img, frame_id=i, video_len=len(imgs))
  87. if args.out is not None:
  88. if in_video or out_video:
  89. out_file = osp.join(out_path, f'{i:06d}.jpg')
  90. else:
  91. out_file = osp.join(out_path, img.rsplit(os.sep, 1)[-1])
  92. else:
  93. out_file = None
  94. # show the results
  95. visualizer.add_datasample(
  96. 'mot',
  97. img[..., ::-1],
  98. data_sample=result[0],
  99. show=args.show,
  100. draw_gt=False,
  101. out_file=out_file,
  102. wait_time=float(1 / int(fps)) if fps else 0,
  103. pred_score_thr=args.score_thr,
  104. step=i)
  105. prog_bar.update()
  106. if args.out and out_video:
  107. print(f'making the output video at {args.out} with a FPS of {fps}')
  108. mmcv.frames2video(out_path, args.out, fps=fps, fourcc='mp4v')
  109. out_dir.cleanup()
  110. if __name__ == '__main__':
  111. args = parse_args()
  112. main(args)