demo.py 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from argparse import ArgumentParser
  3. from mmengine.config import Config
  4. from mmengine.logging import print_log
  5. from mmdet.apis import DetInferencer
  6. from projects.XDecoder.xdecoder.inference import (
  7. ImageCaptionInferencer, RefImageCaptionInferencer,
  8. TextToImageRegionRetrievalInferencer)
  # Maps the task name read from the model config (``cfg.model.head.task``)
  # to the inferencer class that is instantiated for it. The four
  # segmentation-style tasks share ``DetInferencer``; the captioning and
  # retrieval tasks each use a dedicated X-Decoder inferencer.
  TASKINFOS = {
      'semseg': DetInferencer,
      'ref-seg': DetInferencer,
      'instance': DetInferencer,
      'panoptic': DetInferencer,
      'caption': ImageCaptionInferencer,
      'ref-caption': RefImageCaptionInferencer,
      'retrieval': TextToImageRegionRetrievalInferencer,
  }
  def parse_args(argv=None):
      """Parse the demo's command-line options and split them by consumer.

      Args:
          argv (list[str], optional): Argument strings to parse. Defaults to
              None, in which case argparse falls back to ``sys.argv[1:]``.

      Returns:
          tuple[dict, dict]: ``(init_args, call_args)``. ``init_args`` holds
          the ``model``, ``weights``, ``device`` and ``palette`` options.
          ``call_args`` holds every remaining option (``inputs``, ``texts``,
          ``out_dir``, ``show``, ``no_save_vis``, ``pred_score_thr`` and
          ``stuff_texts``).
      """
      parser = ArgumentParser()
      parser.add_argument(
          'inputs', type=str, help='Input image file or folder path.')
      parser.add_argument('model', type=str, help='Config file name')
      parser.add_argument('--weights', help='Checkpoint file')
      parser.add_argument('--texts', help='text prompt')
      parser.add_argument(
          '--out-dir',
          type=str,
          default='outputs',
          help='Output directory of images or prediction results.')
      parser.add_argument(
          '--device', default='cuda:0', help='Device used for inference')
      parser.add_argument(
          '--show',
          action='store_true',
          help='Display the image in a popup window.')
      parser.add_argument(
          '--no-save-vis',
          action='store_true',
          help='Do not save detection vis results')
      parser.add_argument(
          '--palette',
          default='none',
          choices=['ade20k', 'coco', 'voc', 'citys', 'random', 'none'],
          help='Color palette used for visualization')

      # only for instance segmentation
      parser.add_argument(
          '--pred-score-thr',
          type=float,
          default=0.5,
          help='bbox score threshold')

      # only for panoptic segmentation
      parser.add_argument(
          '--stuff-texts',
          help='text prompt for stuff name in panoptic segmentation')

      call_args = vars(parser.parse_args(argv))

      # An empty output directory is how "do not save" is signalled
      # downstream, so --no-save-vis overrides any --out-dir value.
      if call_args['no_save_vis']:
          call_args['out_dir'] = ''

      # Move the constructor options out of call_args so that what remains
      # can be forwarded unchanged to the inferencer call.
      init_kws = ('model', 'weights', 'device', 'palette')
      init_args = {init_kw: call_args.pop(init_kw) for init_kw in init_kws}

      return init_args, call_args
  def main():
      """Build the inferencer for the configured task and run it.

      The task name is read from ``cfg.model.head.task`` of the config file
      given on the command line and looked up in ``TASKINFOS`` to pick the
      inferencer class. The remaining command-line options are forwarded to
      the inferencer call after dropping the prompt options the task does
      not use.

      Raises:
          ValueError: If the configured task is not a key of ``TASKINFOS``,
              or if a task other than ``'caption'`` is run without
              ``--texts``.
      """
      init_args, call_args = parse_args()

      cfg = Config.fromfile(init_args['model'])
      task = cfg.model.head.task
      # Explicit raise instead of ``assert`` so the check survives
      # ``python -O``.
      if task not in TASKINFOS:
          raise ValueError(
              f'unsupported task {task!r}, expected one of {list(TASKINFOS)}')

      inferencer = TASKINFOS[task](**init_args)

      if task == 'caption':
          # Captioning takes no text prompt at all.
          call_args.pop('texts')
          call_args.pop('stuff_texts')
      else:
          if call_args['texts'] is None:
              raise ValueError(f'text prompts is required for {task}')
          # Only panoptic segmentation keeps the stuff prompt.
          if task != 'panoptic':
              call_args.pop('stuff_texts')

      inferencer(**call_args)

      if call_args['out_dir'] != '' and not call_args['no_save_vis']:
          print_log(f'results have been saved at {call_args["out_dir"]}')
  # Run the demo only when this file is executed as a script, not on import.
  if __name__ == '__main__':
      main()