# texttoimage_regionretrieval_inferencer.py

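"""Inferencer that retrieves the image best matching a text query and then
grounds that query in the retrieved image.

The flow implemented in ``__call__``:

1. Score every candidate image against the text with the retrieval pipeline.
2. Softmax the scores and report the best-matching image and its probability.
3. Switch the model to the ``ref-seg`` task, run the grounding pipeline on
   that single image, then visualize and post-process the predictions.
"""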
import copy
from typing import Iterable, Optional, Union

import torch
from mmengine.dataset import Compose
from rich.progress import track

from mmdet.apis.det_inferencer import DetInferencer, InputsType
from mmdet.utils import ConfigType


class TextToImageRegionRetrievalInferencer(DetInferencer):

    def _init_pipeline(self, cfg: ConfigType) -> dict:
        """Initialize the retrieval and grounding test pipelines."""
        pipeline_cfg = cfg.test_dataloader.dataset.pipeline

        # For inference, the key of ``img_id`` is not used.
        if 'meta_keys' in pipeline_cfg[-1]:
            pipeline_cfg[-1]['meta_keys'] = tuple(
                meta_key for meta_key in pipeline_cfg[-1]['meta_keys']
                if meta_key != 'img_id')

        load_img_idx = self._get_transform_idx(pipeline_cfg,
                                               'LoadImageFromFile')
        if load_img_idx == -1:
            raise ValueError(
                'LoadImageFromFile is not found in the test pipeline')

        pipeline_cfg[load_img_idx]['type'] = 'mmdet.InferencerLoader'
        retrieval_pipeline = Compose(pipeline_cfg)

        # The grounding pipeline reuses the test pipeline, but the second
        # transform (assumed to be the resize step) is rescaled to
        # ``cfg.grounding_scale``.
        grounding_pipeline_cp = copy.deepcopy(pipeline_cfg)
        grounding_pipeline_cp[1].scale = cfg.grounding_scale
        grounding_pipeline = Compose(grounding_pipeline_cp)

        return {
            'grounding_pipeline': grounding_pipeline,
            'retrieval_pipeline': retrieval_pipeline
        }
    def _get_chunk_data(self, inputs: Iterable, pipeline, chunk_size: int):
        """Get batch data from inputs.

        Args:
            inputs (Iterable): An iterable dataset.
            pipeline (Compose): The pipeline used to process each input.
            chunk_size (int): Equivalent to batch size.

        Yields:
            list: batch data.
        """
        inputs_iter = iter(inputs)
        while True:
            try:
                chunk_data = []
                for _ in range(chunk_size):
                    inputs_ = next(inputs_iter)
                    chunk_data.append(
                        (inputs_, pipeline(copy.deepcopy(inputs_))))
                yield chunk_data
            except StopIteration:
                if chunk_data:
                    yield chunk_data
                break
    def preprocess(self,
                   inputs: InputsType,
                   pipeline,
                   batch_size: int = 1,
                   **kwargs):
        """Process the inputs into a model-feedable format.

        Customize your preprocess by overriding this method. Preprocess should
        return an iterable object, of which each item will be used as the
        input of ``model.test_step``.

        ``BaseInferencer.preprocess`` will return an iterable chunked data,
        which will be used in __call__ like this:

        .. code-block:: python

            def __call__(self, inputs, batch_size=1, **kwargs):
                chunked_data = self.preprocess(inputs, batch_size, **kwargs)
                for batch in chunked_data:
                    preds = self.forward(batch, **kwargs)

        Args:
            inputs (InputsType): Inputs given by user.
            pipeline (Compose): The pipeline used to process the inputs.
            batch_size (int): batch size. Defaults to 1.

        Yields:
            Any: Data processed by the ``pipeline`` and ``collate_fn``.
        """
        chunked_data = self._get_chunk_data(inputs, pipeline, batch_size)
        yield from map(self.collate_fn, chunked_data)
    def __call__(
            self,
            inputs: InputsType,
            batch_size: int = 1,
            return_vis: bool = False,
            show: bool = False,
            wait_time: int = 0,
            no_save_vis: bool = False,
            draw_pred: bool = True,
            pred_score_thr: float = 0.3,
            return_datasample: bool = False,
            print_result: bool = False,
            no_save_pred: bool = True,
            out_dir: str = '',
            texts: Optional[Union[str, list]] = None,
            # by open panoptic task
            stuff_texts: Optional[Union[str, list]] = None,
            custom_entities: bool = False,  # by GLIP
            **kwargs) -> dict:
        """Call the inferencer.

        Args:
            inputs (InputsType): Inputs for the inferencer.
            batch_size (int): Inference batch size. Defaults to 1.
            return_vis (bool): Whether to return the visualization result.
                Defaults to False.
            show (bool): Whether to display the visualization results in a
                popup window. Defaults to False.
            wait_time (float): The interval of show (s). Defaults to 0.
            no_save_vis (bool): Whether to force not to save prediction
                vis results. Defaults to False.
            draw_pred (bool): Whether to draw predicted bounding boxes.
                Defaults to True.
            pred_score_thr (float): Minimum score of bboxes to draw.
                Defaults to 0.3.
            return_datasample (bool): Whether to return results as
                :obj:`DetDataSample`. Defaults to False.
            print_result (bool): Whether to print the inference result w/o
                visualization to the console. Defaults to False.
            no_save_pred (bool): Whether to force not to save prediction
                results. Defaults to True.
            out_dir (str): Dir to save the inference results or
                visualization. If left as empty, no file will be saved.
                Defaults to ''.
            texts (str | list[str], optional): Text query used to retrieve
                and ground the image. If a single string is given, it is
                applied to every input image. Defaults to None.
            **kwargs: Other keyword arguments passed to :meth:`preprocess`,
                :meth:`forward`, :meth:`visualize` and :meth:`postprocess`.
                Each key in kwargs should be in the corresponding set of
                ``preprocess_kwargs``, ``forward_kwargs``, ``visualize_kwargs``
                and ``postprocess_kwargs``.

        Returns:
            dict: Inference and visualization results.
        """
        (
            preprocess_kwargs,
            forward_kwargs,
            visualize_kwargs,
            postprocess_kwargs,
        ) = self._dispatch_kwargs(**kwargs)

        ori_inputs = self._inputs_to_list(inputs)

        if isinstance(texts, str):
            texts = [texts] * len(ori_inputs)

        for i in range(len(texts)):
            ori_inputs[i] = {
                'img_path': ori_inputs[i],
                'text': texts[i],
                'custom_entities': False
            }
        inputs = self.preprocess(
            ori_inputs,
            pipeline=self.pipeline['retrieval_pipeline'],
            batch_size=batch_size,
            **preprocess_kwargs)

        self.model.sem_seg_head._force_not_use_cache = True

        pred_scores = []
        for _, retrieval_data in track(inputs, description='Inference'):
            preds = self.forward(retrieval_data, **forward_kwargs)
            pred_scores.append(preds[0].pred_score)

        pred_score = torch.cat(pred_scores)
        pred_score = torch.softmax(pred_score, dim=0)
        max_id = torch.argmax(pred_score)
        retrieval_ori_input = ori_inputs[max_id.item()]
        max_prob = round(pred_score[max_id].item(), 3)
        print(
            'The image that best matches the given text is '
            f"{retrieval_ori_input['img_path']} and probability is {max_prob}")

        inputs = self.preprocess([retrieval_ori_input],
                                 pipeline=self.pipeline['grounding_pipeline'],
                                 batch_size=1,
                                 **preprocess_kwargs)

        self.model.task = 'ref-seg'
        self.model.sem_seg_head.task = 'ref-seg'
        self.model.sem_seg_head.predictor.task = 'ref-seg'

        ori_inputs, grounding_data = next(inputs)

        if isinstance(ori_inputs, dict):
            ori_inputs = ori_inputs['img_path']

        preds = self.forward(grounding_data, **forward_kwargs)

        visualization = self.visualize(
            ori_inputs,
            preds,
            return_vis=return_vis,
            show=show,
            wait_time=wait_time,
            draw_pred=draw_pred,
            pred_score_thr=pred_score_thr,
            no_save_vis=no_save_vis,
            img_out_dir=out_dir,
            **visualize_kwargs)
        results = self.postprocess(
            preds,
            visualization,
            return_datasample=return_datasample,
            print_result=print_result,
            no_save_pred=no_save_pred,
            pred_out_dir=out_dir,
            **postprocess_kwargs)
        if results['visualization'] is not None:
            # Only a single image is visualized in the grounding stage, so
            # unwrap the one-element visualization list.
            results['visualization'] = results['visualization'][0]
        return results
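

# Usage sketch (assumptions: the import path, config name and checkpoint name
# below are placeholders, not defined in this file; substitute the X-Decoder
# retrieval config/weights you actually have):
#
#     from <this_module> import TextToImageRegionRetrievalInferencer
#
#     inferencer = TextToImageRegionRetrievalInferencer(
#         model='xdecoder_retrieval_config.py',  # placeholder config
#         weights='xdecoder_retrieval.pth')      # placeholder checkpoint
#     results = inferencer(
#         inputs=['candidate_1.jpg', 'candidate_2.jpg'],  # images to search
#         texts='a dog catching a frisbee',               # text query
#         batch_size=2)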