image_caption.py

import copy
import os.path as osp
from typing import Iterable, List, Optional, Tuple, Union

import mmcv
import mmengine
import numpy as np
import torch
from mmengine.dataset import Compose
from rich.progress import track

from mmdet.apis.det_inferencer import DetInferencer, InputsType, PredType
from mmdet.utils import ConfigType


def get_adaptive_scale(img_shape: Tuple[int, int],
                       min_scale: float = 0.3,
                       max_scale: float = 3.0) -> float:
    """Get an adaptive scale according to the image shape.

    The target scale depends on the short edge length of the image: if the
    short edge length equals 224, the output is 1.0; otherwise the output
    scales linearly with the short edge length.

    You can also specify the minimum scale and the maximum scale to limit
    the linear scale.

    Args:
        img_shape (Tuple[int, int]): The shape of the canvas image.
        min_scale (float): The minimum scale. Defaults to 0.3.
        max_scale (float): The maximum scale. Defaults to 3.0.

    Returns:
        float: The adaptive scale.
    """
    short_edge_length = min(img_shape)
    scale = short_edge_length / 224.
    return min(max(scale, min_scale), max_scale)
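
# A quick sanity check (sketch): a 448x672 image has short edge 448, so the
# scale is 448 / 224 = 2.0, while a 64x64 image clamps to ``min_scale``:
#
#   get_adaptive_scale((448, 672))  # -> 2.0
#   get_adaptive_scale((64, 64))    # -> 0.3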


class ImageCaptionInferencer(DetInferencer):

    DEFAULT_TEXT_CFG = {
        'font_families': 'monospace',
        'colors': 'white',
        'bboxes': dict(facecolor='black', alpha=0.5, boxstyle='Round'),
        'vertical_alignments': 'top',
        'horizontal_alignments': 'left',
    }

    def visualize(self,
                  inputs: InputsType,
                  preds: PredType,
                  return_vis: bool = False,
                  show: bool = False,
                  wait_time: int = 0,
                  draw_pred: bool = True,
                  pred_score_thr: float = 0.3,
                  no_save_vis: bool = False,
                  img_out_dir: str = '',
                  **kwargs) -> Union[List[np.ndarray], None]:
        if no_save_vis is True:
            img_out_dir = ''

        if not show and img_out_dir == '' and not return_vis:
            return None

        if self.visualizer is None:
            raise ValueError('Visualization needs the "visualizer" term '
                             'defined in the config, but got None.')

        results = []
        # Copy the class-level config so that setting the per-image
        # ``font_sizes`` below does not mutate ``DEFAULT_TEXT_CFG``.
        text_cfg = self.DEFAULT_TEXT_CFG.copy()

        for single_input, pred in zip(inputs, preds):
            if isinstance(single_input, str):
                img_bytes = mmengine.fileio.get(single_input)
                img = mmcv.imfrombytes(img_bytes)
                img = img[:, :, ::-1]  # BGR -> RGB
                img_name = osp.basename(single_input)
            elif isinstance(single_input, np.ndarray):
                img = single_input.copy()
                img_num = str(self.num_visualized_imgs).zfill(8)
                img_name = f'{img_num}.jpg'
            else:
                raise ValueError('Unsupported input type: '
                                 f'{type(single_input)}')

            out_file = osp.join(img_out_dir, 'vis',
                                img_name) if img_out_dir != '' else None

            self.visualizer.set_image(img)
            img_scale = get_adaptive_scale(img.shape[:2])
            text_cfg['font_sizes'] = int(img_scale * 7)

            self.visualizer.draw_texts(
                pred.pred_caption,
                torch.tensor([img_scale * 5, img_scale * 5]), **text_cfg)
            drawn_img = self.visualizer.get_image()

            self.visualizer.add_datasample(
                img_name,
                drawn_img,
                pred,
                show=show,
                wait_time=wait_time,
                draw_gt=False,
                draw_pred=draw_pred,
                pred_score_thr=pred_score_thr,
                out_file=out_file,
            )
            results.append(self.visualizer.get_image())
            self.num_visualized_imgs += 1

        return results
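
# Usage sketch (the config and checkpoint names below are hypothetical; any
# XDecoder-style caption config registered with mmdet should work the same):
#
#   inferencer = ImageCaptionInferencer(
#       model='xdecoder-tiny_zeroshot_caption_coco2014.py',
#       weights='xdecoder_tiny.pth')
#   results = inferencer('demo.jpg', return_vis=True, return_datasample=True)
#   # Each returned data sample carries ``pred_caption``, the field that
#   # ``visualize`` above draws onto the image.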


class RefImageCaptionInferencer(ImageCaptionInferencer):

    def _init_pipeline(self, cfg: ConfigType) -> dict:
        """Initialize the test pipeline."""
        pipeline_cfg = cfg.test_dataloader.dataset.pipeline

        # For inference, the key of ``img_id`` is not used.
        if 'meta_keys' in pipeline_cfg[-1]:
            pipeline_cfg[-1]['meta_keys'] = tuple(
                meta_key for meta_key in pipeline_cfg[-1]['meta_keys']
                if meta_key != 'img_id')

        load_img_idx = self._get_transform_idx(pipeline_cfg,
                                               'LoadImageFromFile')
        if load_img_idx == -1:
            raise ValueError(
                'LoadImageFromFile is not found in the test pipeline')
        pipeline_cfg[load_img_idx]['type'] = 'mmdet.InferencerLoader'
        caption_pipeline = Compose(pipeline_cfg)

        grounding_pipeline_cp = copy.deepcopy(pipeline_cfg)
        # Assumes the resize transform sits at index 1 of the test pipeline.
        grounding_pipeline_cp[1].scale = cfg.grounding_scale
        grounding_pipeline = Compose(grounding_pipeline_cp)

        return {
            'grounding_pipeline': grounding_pipeline,
            'caption_pipeline': caption_pipeline
        }
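
    # Note: unlike the base class, ``_init_pipeline`` returns a dict of two
    # ``Compose`` pipelines rather than a single one; ``_get_chunk_data``
    # below indexes ``self.pipeline`` with the 'grounding_pipeline' /
    # 'caption_pipeline' keys accordingly.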

    def _get_chunk_data(self, inputs: Iterable, chunk_size: int):
        """Get batch data from inputs.

        Args:
            inputs (Iterable): An iterable dataset.
            chunk_size (int): Equivalent to batch size.

        Yields:
            list: batch data.
        """
        inputs_iter = iter(inputs)
        while True:
            try:
                chunk_data = []
                for _ in range(chunk_size):
                    inputs_ = next(inputs_iter)
                    if 'img' in inputs_:
                        ori_inputs_ = inputs_['img']
                    else:
                        ori_inputs_ = inputs_['img_path']
                    chunk_data.append(
                        (ori_inputs_,
                         self.pipeline['grounding_pipeline'](
                             copy.deepcopy(inputs_)),
                         self.pipeline['caption_pipeline'](
                             copy.deepcopy(inputs_))))
                yield chunk_data
            except StopIteration:
                if chunk_data:
                    yield chunk_data
                break
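
    # Each yielded chunk is a list of ``chunk_size`` tuples of the form
    # (ori_input, grounding_sample, caption_sample); the raw input is
    # deep-copied so the two pipelines cannot contaminate each other. E.g.
    # four inputs with chunk_size=2 yield two chunks of two tuples each.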

    def __call__(
            self,
            inputs: InputsType,
            batch_size: int = 1,
            return_vis: bool = False,
            show: bool = False,
            wait_time: int = 0,
            no_save_vis: bool = False,
            draw_pred: bool = True,
            pred_score_thr: float = 0.3,
            return_datasample: bool = False,
            print_result: bool = False,
            no_save_pred: bool = True,
            out_dir: str = '',
            texts: Optional[Union[str, list]] = None,
            # by open panoptic task
            stuff_texts: Optional[Union[str, list]] = None,
            custom_entities: bool = False,  # by GLIP
            **kwargs) -> dict:
        """Call the inferencer.

        Args:
            inputs (InputsType): Inputs for the inferencer.
            batch_size (int): Inference batch size. Defaults to 1.
            return_vis (bool): Whether to return the visualization results.
                Defaults to False.
            show (bool): Whether to display the visualization results in a
                popup window. Defaults to False.
            wait_time (float): The interval of show (s). Defaults to 0.
            no_save_vis (bool): Whether to force not to save prediction
                vis results. Defaults to False.
            draw_pred (bool): Whether to draw predicted bounding boxes.
                Defaults to True.
            pred_score_thr (float): Minimum score of bboxes to draw.
                Defaults to 0.3.
            return_datasample (bool): Whether to return results as
                :obj:`DetDataSample`. Defaults to False.
            print_result (bool): Whether to print the inference result w/o
                visualization to the console. Defaults to False.
            no_save_pred (bool): Whether to force not to save prediction
                results. Defaults to True.
            out_dir (str): Dir to save the inference results or
                visualization. If left as empty, no file will be saved.
                Defaults to ''.
            texts (str | list[str], optional): The referring expression(s)
                used for grounding, one per input image. A single string is
                broadcast to all inputs.
            **kwargs: Other keyword arguments passed to :meth:`preprocess`,
                :meth:`forward`, :meth:`visualize` and :meth:`postprocess`.
                Each key in kwargs should be in the corresponding set of
                ``preprocess_kwargs``, ``forward_kwargs``,
                ``visualize_kwargs`` and ``postprocess_kwargs``.

        Returns:
            dict: Inference and visualization results.
        """
        assert batch_size == 1
        (
            preprocess_kwargs,
            forward_kwargs,
            visualize_kwargs,
            postprocess_kwargs,
        ) = self._dispatch_kwargs(**kwargs)

        ori_inputs = self._inputs_to_list(inputs)

        if isinstance(texts, str):
            texts = [texts] * len(ori_inputs)

        for i in range(len(texts)):
            if isinstance(ori_inputs[i], str):
                ori_inputs[i] = {
                    'text': texts[i],
                    'img_path': ori_inputs[i],
                    'custom_entities': custom_entities
                }
            else:
                ori_inputs[i] = {
                    'text': texts[i],
                    'img': ori_inputs[i],
                    'custom_entities': custom_entities
                }

        inputs = self.preprocess(
            ori_inputs, batch_size=batch_size, **preprocess_kwargs)

        results_dict = {'predictions': [], 'visualization': []}
        for ori_input, grounding_data, caption_data in track(
                inputs, description='Inference'):

            # First pass: referring segmentation to ground the expression.
            self.model.sem_seg_head.task = 'ref-seg'
            self.model.sem_seg_head.predictor.task = 'ref-seg'
            preds = self.forward(grounding_data, **forward_kwargs)

            for data_sample, pred_datasample in zip(
                    caption_data['data_samples'], preds):
                data_sample.pred_instances = pred_datasample.pred_instances
                data_sample.set_metainfo({
                    'grounding_img_shape':
                    pred_datasample.metainfo['img_shape']
                })

            # Second pass: captioning conditioned on the grounded instances.
            self.model.sem_seg_head.task = 'caption'
            self.model.sem_seg_head.predictor.task = 'caption'
            preds = self.forward(caption_data, **forward_kwargs)

            if isinstance(ori_input, dict):
                ori_input = ori_input['img_path']

            visualization = self.visualize(
                ori_input,
                preds,
                return_vis=return_vis,
                show=show,
                wait_time=wait_time,
                draw_pred=draw_pred,
                pred_score_thr=pred_score_thr,
                no_save_vis=no_save_vis,
                img_out_dir=out_dir,
                **visualize_kwargs)
            results = self.postprocess(
                preds,
                visualization,
                return_datasample=return_datasample,
                print_result=print_result,
                no_save_pred=no_save_pred,
                pred_out_dir=out_dir,
                **postprocess_kwargs)
            results_dict['predictions'].extend(results['predictions'])
            if results['visualization'] is not None:
                results_dict['visualization'].extend(
                    results['visualization'])
        return results_dict
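

# Usage sketch for the referring-caption inferencer (config/checkpoint names
# are hypothetical; the loaded model must expose ``sem_seg_head.task`` and
# ``sem_seg_head.predictor.task``, which ``__call__`` above toggles between
# 'ref-seg' and 'caption'):
#
#   inferencer = RefImageCaptionInferencer(
#       model='xdecoder-tiny_zeroshot_ref-caption.py',
#       weights='xdecoder_tiny.pth')
#   out = inferencer('demo.jpg', texts='the red apple', batch_size=1)
#   # out['predictions'] holds the caption grounded on the referred region.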