# Modified from MMPretrain
import gradio as gr
import torch
from mmengine.logging import MMLogger

from mmdet.apis import DetInferencer
from projects.XDecoder.xdecoder.inference import (
    ImageCaptionInferencer, RefImageCaptionInferencer,
    TextToImageRegionRetrievalInferencer)

logger = MMLogger('mmdetection', logger_name='mmdet')

if torch.cuda.is_available():
    gpus = [
        torch.device(f'cuda:{i}') for i in range(torch.cuda.device_count())
    ]
    logger.info(f'Available GPUs: {len(gpus)}')
else:
    gpus = None
    logger.info('No available GPU.')
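

# Pick a device for each inference call: the GPU with the most free memory
# when that can be queried, otherwise a random GPU, otherwise the CPU.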
def get_free_device():
    if gpus is None:
        return torch.device('cpu')
    if hasattr(torch.cuda, 'mem_get_info'):
        free = [torch.cuda.mem_get_info(gpu)[0] for gpu in gpus]
        select = max(zip(free, range(len(free))))[1]
    else:
        import random
        select = random.randint(0, len(gpus) - 1)
    return gpus[select]
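

# Closed-set object detection tab: choose a COCO-trained detector from the
# list and visualize its predictions on the uploaded image.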
class ObjectDetectionTab:
    model_list = [
        'retinanet_r50-caffe_fpn_1x_coco',
        'faster-rcnn_r50-caffe_fpn_1x_coco',
        'dino-5scale_swin-l_8xb2-12e_coco.py',
    ]

    def __init__(self) -> None:
        self.create_ui()

    def create_ui(self):
        with gr.Row():
            with gr.Column():
                select_model = gr.Dropdown(
                    label='Choose a model',
                    elem_id='od_models',
                    elem_classes='select_model',
                    choices=self.model_list,
                    value=self.model_list[0],
                )
            with gr.Column():
                image_input = gr.Image(
                    label='Image',
                    source='upload',
                    elem_classes='input_image',
                    type='filepath',
                    interactive=True,
                    tool='editor',
                )
                output = gr.Image(
                    label='Result',
                    source='upload',
                    interactive=False,
                    elem_classes='result',
                )
                run_button = gr.Button(
                    'Run',
                    elem_classes='run_button',
                )
                run_button.click(
                    self.inference,
                    inputs=[select_model, image_input],
                    outputs=output,
                )
        with gr.Row():
            example_images = gr.Dataset(
                components=[image_input], samples=[['demo/demo.jpg']])
            example_images.click(
                fn=lambda x: gr.Image.update(value=x[0]),
                inputs=example_images,
                outputs=image_input)

    def inference(self, model, image):
        det_inferencer = DetInferencer(
            model, scope='mmdet', device=get_free_device())
        results_dict = det_inferencer(image, return_vis=True, no_save_vis=True)
        vis = results_dict['visualization'][0]
        return vis


class InstanceSegTab(ObjectDetectionTab):
    model_list = ['mask-rcnn_r50-caffe_fpn_1x_coco', 'solov2_r50_fpn_1x_coco']


class PanopticSegTab(ObjectDetectionTab):
    model_list = [
        'panoptic_fpn_r50_fpn_1x_coco',
        'mask2former_swin-s-p4-w7-224_8xb2-lsj-50e_coco-panoptic'
    ]
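

# Open-vocabulary detection tab (GLIP): target categories are typed as a
# text prompt such as 'bench . car .' and parsed as custom entities.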
class OpenVocabObjectDetectionTab:
    model_list = ['glip_atss_swin-t_a_fpn_dyhead_pretrain_obj365']

    def __init__(self) -> None:
        self.create_ui()

    def create_ui(self):
        with gr.Row():
            with gr.Column():
                select_model = gr.Dropdown(
                    label='Choose a model',
                    elem_id='od_models',
                    elem_classes='select_model',
                    choices=self.model_list,
                    value=self.model_list[0],
                )
            with gr.Column():
                image_input = gr.Image(
                    label='Image',
                    source='upload',
                    elem_classes='input_image',
                    type='filepath',
                    interactive=True,
                    tool='editor',
                )
                text_input = gr.Textbox(
                    label='text prompt',
                    elem_classes='input_text',
                    interactive=True,
                )
                output = gr.Image(
                    label='Result',
                    source='upload',
                    interactive=False,
                    elem_classes='result',
                )
                run_button = gr.Button(
                    'Run',
                    elem_classes='run_button',
                )
                run_button.click(
                    self.inference,
                    inputs=[select_model, image_input, text_input],
                    outputs=output,
                )
        with gr.Row():
            example_images = gr.Dataset(
                components=[image_input, text_input],
                samples=[['demo/demo.jpg', 'bench . car .']])
            example_images.click(
                fn=self.update,
                inputs=example_images,
                outputs=[image_input, text_input])

    def update(self, example):
        return gr.Image.update(value=example[0]), gr.Textbox.update(
            value=example[1])

    def inference(self, model, image, text):
        det_inferencer = DetInferencer(
            model, scope='mmdet', device=get_free_device())
        results_dict = det_inferencer(
            image,
            texts=text,
            custom_entities=True,
            pred_score_thr=0.5,
            return_vis=True,
            no_save_vis=True)
        vis = results_dict['visualization'][0]
        return vis
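

# Phrase grounding tab: the same GLIP model, but the prompt is a full
# sentence (custom_entities=False) and the model grounds the phrases in it.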
class GroundingDetectionTab(OpenVocabObjectDetectionTab):
    model_list = ['glip_atss_swin-t_a_fpn_dyhead_pretrain_obj365']

    def create_ui(self):
        with gr.Row():
            with gr.Column():
                select_model = gr.Dropdown(
                    label='Choose a model',
                    elem_id='od_models',
                    elem_classes='select_model',
                    choices=self.model_list,
                    value=self.model_list[0],
                )
            with gr.Column():
                image_input = gr.Image(
                    label='Image',
                    source='upload',
                    elem_classes='input_image',
                    type='filepath',
                    interactive=True,
                    tool='editor',
                )
                text_input = gr.Textbox(
                    label='text prompt',
                    elem_classes='input_text',
                    interactive=True,
                )
                output = gr.Image(
                    label='Result',
                    source='upload',
                    interactive=False,
                    elem_classes='result',
                )
                run_button = gr.Button(
                    'Run',
                    elem_classes='run_button',
                )
                run_button.click(
                    self.inference,
                    inputs=[select_model, image_input, text_input],
                    outputs=output,
                )
        with gr.Row():
            example_images = gr.Dataset(
                components=[image_input, text_input],
                samples=[['demo/demo.jpg', 'There are a lot of cars here.']])
            example_images.click(
                fn=self.update,
                inputs=example_images,
                outputs=[image_input, text_input])

    def inference(self, model, image, text):
        det_inferencer = DetInferencer(
            model, scope='mmdet', device=get_free_device())
        results_dict = det_inferencer(
            image,
            texts=text,
            custom_entities=False,
            pred_score_thr=0.5,
            return_vis=True,
            no_save_vis=True)
        vis = results_dict['visualization'][0]
        return vis
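

# The remaining tabs wrap X-Decoder; model_info maps each model name to its
# config file and pretrained weights, which are passed to the inferencer.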
class OpenVocabInstanceSegTab(OpenVocabObjectDetectionTab):
    model_list = ['xdecoder-tiny']

    model_info = {
        'xdecoder-tiny': {
            'model':
            'projects/XDecoder/configs/xdecoder-tiny_zeroshot_open-vocab-instance_coco.py',  # noqa
            'weights':
            'https://download.openmmlab.com/mmdetection/v3.0/xdecoder/xdecoder_focalt_last_novg.pt'  # noqa
        }
    }

    def inference(self, model, image, text):
        det_inferencer = DetInferencer(
            **self.model_info[model], scope='mmdet', device=get_free_device())
        results_dict = det_inferencer(
            image, texts=text, return_vis=True, no_save_vis=True)
        vis = results_dict['visualization'][0]
        return vis
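

# Open-vocabulary panoptic segmentation takes two prompts: one for "thing"
# categories and one for "stuff" categories.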
class OpenVocabPanopticSegTab(OpenVocabObjectDetectionTab):
    model_list = ['xdecoder-tiny']

    model_info = {
        'xdecoder-tiny': {
            'model':
            'projects/XDecoder/configs/xdecoder-tiny_zeroshot_open-vocab-panoptic_coco.py',  # noqa
            'weights':
            'https://download.openmmlab.com/mmdetection/v3.0/xdecoder/xdecoder_focalt_last_novg.pt'  # noqa
        }
    }

    def create_ui(self):
        with gr.Row():
            with gr.Column():
                select_model = gr.Dropdown(
                    label='Choose a model',
                    elem_id='od_models',
                    elem_classes='select_model',
                    choices=self.model_list,
                    value=self.model_list[0],
                )
            with gr.Column():
                image_input = gr.Image(
                    label='Image',
                    source='upload',
                    elem_classes='input_image',
                    type='filepath',
                    interactive=True,
                    tool='editor',
                )
                text_input = gr.Textbox(
                    label='thing text prompt',
                    elem_classes='input_text_thing',
                    interactive=True,
                )
                stuff_text_input = gr.Textbox(
                    label='stuff text prompt',
                    elem_classes='input_text_stuff',
                    interactive=True,
                )
                output = gr.Image(
                    label='Result',
                    source='upload',
                    interactive=False,
                    elem_classes='result',
                )
                run_button = gr.Button(
                    'Run',
                    elem_classes='run_button',
                )
                run_button.click(
                    self.inference,
                    inputs=[
                        select_model, image_input, text_input, stuff_text_input
                    ],
                    outputs=output,
                )
        with gr.Row():
            example_images = gr.Dataset(
                components=[image_input, text_input, stuff_text_input],
                samples=[['demo/demo.jpg', 'bench.car', 'tree']])
            example_images.click(
                fn=self.update,
                inputs=example_images,
                outputs=[image_input, text_input, stuff_text_input])

    def update(self, example):
        return gr.Image.update(value=example[0]), \
            gr.Textbox.update(label='thing text prompt', value=example[1]), \
            gr.Textbox.update(label='stuff text prompt', value=example[2])

    def inference(self, model, image, text, stuff_text):
        det_inferencer = DetInferencer(
            **self.model_info[model], scope='mmdet', device=get_free_device())
        results_dict = det_inferencer(
            image,
            texts=text,
            stuff_texts=stuff_text,
            return_vis=True,
            no_save_vis=True)
        vis = results_dict['visualization'][0]
        return vis
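

# Open-vocabulary semantic segmentation and referring segmentation reuse the
# OV instance-seg UI and inference; only the config and weights differ.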
class OpenVocabSemSegTab(OpenVocabInstanceSegTab):
    model_list = ['xdecoder-tiny']

    model_info = {
        'xdecoder-tiny': {
            'model':
            'projects/XDecoder/configs/xdecoder-tiny_zeroshot_open-vocab-semseg_coco.py',  # noqa
            'weights':
            'https://download.openmmlab.com/mmdetection/v3.0/xdecoder/xdecoder_focalt_last_novg.pt'  # noqa
        }
    }


class ReferSegTab(OpenVocabInstanceSegTab):
    model_list = ['xdecoder-tiny']

    model_info = {
        'xdecoder-tiny': {
            'model':
            'projects/XDecoder/configs/xdecoder-tiny_zeroshot_open-vocab-ref-seg_refcocog.py',  # noqa
            'weights':
            'https://download.openmmlab.com/mmdetection/v3.0/xdecoder/xdecoder_focalt_last_novg.pt'  # noqa
        }
    }
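

# Image captioning tab: returns the predicted caption text rather than a
# visualization image.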
class ImageCaptionTab:
    model_list = ['xdecoder-tiny']

    model_info = {
        'xdecoder-tiny': {
            'model':
            'projects/XDecoder/configs/xdecoder-tiny_zeroshot_caption_coco2014.py',  # noqa
            'weights':
            'https://download.openmmlab.com/mmdetection/v3.0/xdecoder/xdecoder_focalt_last_novg.pt'  # noqa
        }
    }

    def __init__(self) -> None:
        self.create_ui()

    def create_ui(self):
        with gr.Row():
            with gr.Column():
                select_model = gr.Dropdown(
                    label='Choose a model',
                    elem_id='image_caption_models',
                    elem_classes='select_model',
                    choices=self.model_list,
                    value=self.model_list[0],
                )
            with gr.Column():
                image_input = gr.Image(
                    label='Input',
                    source='upload',
                    elem_classes='input_image',
                    interactive=True,
                    tool='editor',
                )
                caption_output = gr.Textbox(
                    label='Result',
                    lines=2,
                    elem_classes='caption_result',
                    interactive=False,
                )
                run_button = gr.Button(
                    'Run',
                    elem_classes='run_button',
                )
                run_button.click(
                    self.inference,
                    inputs=[select_model, image_input],
                    outputs=caption_output,
                )
        with gr.Row():
            example_images = gr.Dataset(
                components=[image_input], samples=[['demo/demo.jpg']])
            example_images.click(
                fn=lambda x: gr.Image.update(value=x[0]),
                inputs=example_images,
                outputs=image_input)

    def inference(self, model, image):
        ic_inferencer = ImageCaptionInferencer(
            **self.model_info[model], scope='mmdet', device=get_free_device())
        results_dict = ic_inferencer(
            image, return_vis=False, no_save_vis=True, return_datasample=True)
        return results_dict['predictions'][0].pred_caption
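

# Referring-expression captioning: caption the image region that matches the
# text prompt and return the visualization.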
class ReferImageCaptionTab(OpenVocabInstanceSegTab):
    model_list = ['xdecoder-tiny']

    model_info = {
        'xdecoder-tiny': {
            'model':
            'projects/XDecoder/configs/xdecoder-tiny_zeroshot_ref-caption.py',  # noqa
            'weights':
            'https://download.openmmlab.com/mmdetection/v3.0/xdecoder/xdecoder_focalt_last_novg.pt'  # noqa
        }
    }

    def create_ui(self):
        with gr.Row():
            with gr.Column():
                select_model = gr.Dropdown(
                    label='Choose a model',
                    elem_id='image_caption_models',
                    elem_classes='select_model',
                    choices=self.model_list,
                    value=self.model_list[0],
                )
            with gr.Column():
                image_input = gr.Image(
                    label='Input',
                    source='upload',
                    elem_classes='input_image',
                    type='filepath',
                    interactive=True,
                    tool='editor',
                )
                text_input = gr.Textbox(
                    label='text prompt',
                    elem_classes='input_text',
                    interactive=True,
                )
                output = gr.Image(
                    label='Result',
                    source='upload',
                    interactive=False,
                    elem_classes='result',
                )
                run_button = gr.Button(
                    'Run',
                    elem_classes='run_button',
                )
                run_button.click(
                    self.inference,
                    inputs=[select_model, image_input, text_input],
                    outputs=output,
                )
        with gr.Row():
            example_images = gr.Dataset(
                components=[image_input, text_input],
                samples=[['demo/demo.jpg', 'tree']])
            example_images.click(
                fn=self.update,
                inputs=example_images,
                outputs=[image_input, text_input])

    def update(self, example):
        return gr.Image.update(value=example[0]), gr.Textbox.update(
            value=example[1])

    def inference(self, model, image, text):
        ric_inferencer = RefImageCaptionInferencer(
            **self.model_info[model], scope='mmdet', device=get_free_device())
        results_dict = ric_inferencer(
            image, texts=text, return_vis=True, no_save_vis=True)
        vis = results_dict['visualization'][0]
        return vis
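

# Text-to-image region retrieval: upload several images and a query text; the
# model returns the best-matching image together with its grounding mask.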
class TextToImageRetrievalTab:
    model_list = ['xdecoder-tiny']

    model_info = {
        'xdecoder-tiny': {
            'model':
            'projects/XDecoder/configs/xdecoder-tiny_zeroshot_text-image-retrieval.py',  # noqa
            'weights':
            'https://download.openmmlab.com/mmdetection/v3.0/xdecoder/xdecoder_focalt_last_novg.pt'  # noqa
        }
    }

    def __init__(self) -> None:
        self.create_ui()

    def create_ui(self):
        with gr.Row():
            with gr.Column():
                select_model = gr.Dropdown(
                    label='Choose a model',
                    elem_id='t2i_retri_models',
                    elem_classes='select_model',
                    choices=self.model_list,
                    value=self.model_list[0],
                )
            with gr.Column():
                prototype = gr.File(
                    file_count='multiple', file_types=['image'])
                text_input = gr.Textbox(
                    label='Query',
                    elem_classes='input_text',
                    interactive=True,
                )
                retri_output = gr.Image(
                    label='Result',
                    source='upload',
                    interactive=False,
                    elem_classes='result',
                )
                run_button = gr.Button(
                    'Run',
                    elem_classes='run_button',
                )
                run_button.click(
                    self.inference,
                    inputs=[select_model, prototype, text_input],
                    outputs=retri_output,
                )

    def inference(self, model, prototype, text):
        inputs = [file.name for file in prototype]
        retri_inferencer = TextToImageRegionRetrievalInferencer(
            **self.model_info[model], scope='mmdet', device=get_free_device())
        results_dict = retri_inferencer(
            inputs, texts=text, return_vis=True, no_save_vis=True)
        vis = results_dict['visualization'][0]
        return vis
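

# Assemble every tab into a single Gradio Blocks app.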
if __name__ == '__main__':
    title = 'MMDetection Inference Demo'

    DESCRIPTION = '''# <div align="center">MMDetection Inference Demo</div>

    <div align="center">
    <img src="https://user-images.githubusercontent.com/45811724/190993591-bd3f1f11-1c30-4b93-b5f4-05c9ff64ff7f.gif" width="50%"/>
    </div>

    #### This is an official demo for MMDet. \n

    - The first run needs to download the model weights, so please wait a moment. \n
    - OV stands for Open Vocabulary. \n
    - Refer Seg stands for Referring Expression Segmentation. \n
    - In Text-Image Region Retrieval, you provide n images and a query text,
    and the model predicts the best-matching image and its corresponding
    grounding mask.
    '''

    with gr.Blocks(analytics_enabled=False, title=title) as demo:
        gr.Markdown(DESCRIPTION)
        with gr.Tabs():
            with gr.TabItem('Detection'):
                ObjectDetectionTab()
            with gr.TabItem('Instance'):
                InstanceSegTab()
            with gr.TabItem('Panoptic'):
                PanopticSegTab()
            with gr.TabItem('Grounding Detection'):
                GroundingDetectionTab()
            with gr.TabItem('OV Detection'):
                OpenVocabObjectDetectionTab()
            with gr.TabItem('OV Instance'):
                OpenVocabInstanceSegTab()
            with gr.TabItem('OV Panoptic'):
                OpenVocabPanopticSegTab()
            with gr.TabItem('OV SemSeg'):
                OpenVocabSemSegTab()
            with gr.TabItem('Refer Seg'):
                ReferSegTab()
            with gr.TabItem('Image Caption'):
                ImageCaptionTab()
            with gr.TabItem('Refer Caption'):
                ReferImageCaptionTab()
            with gr.TabItem('Text-Image Region Retrieval'):
                TextToImageRetrievalTab()
    demo.queue().launch(share=True)
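
# Note (assumed usage): launch this script from the MMDetection repository
# root so that the 'projects.XDecoder' import above resolves, e.g.
#   python projects/gradio_demo/launch.py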