# Copyright (c) OpenMMLab. All rights reserved.
import copy
import re
import warnings
from typing import Tuple

import torch
from torch import Tensor

from mmdet.registry import MODELS
from mmdet.structures import SampleList
from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig

from .single_stage import SingleStageDetector


def find_noun_phrases(caption: str) -> list:
    """Find noun phrases in a caption using nltk.

    Args:
        caption (str): The caption to analyze.

    Returns:
        list: List of noun phrases found in the caption.

    Examples:
        >>> caption = 'There is two cat and a remote in the picture'
        >>> find_noun_phrases(caption)  # ['cat', 'a remote', 'the picture']
    """
    try:
        import nltk
        nltk.download('punkt')
        nltk.download('averaged_perceptron_tagger')
    except ImportError:
        raise RuntimeError('nltk is not installed, please install it by: '
                           'pip install nltk.')

    caption = caption.lower()
    tokens = nltk.word_tokenize(caption)
    pos_tags = nltk.pos_tag(tokens)

    grammar = 'NP: {<DT>?<JJ.*>*<NN.*>+}'
    cp = nltk.RegexpParser(grammar)
    result = cp.parse(pos_tags)

    noun_phrases = []
    for subtree in result.subtrees():
        if subtree.label() == 'NP':
            noun_phrases.append(' '.join(t[0] for t in subtree.leaves()))

    return noun_phrases
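
# Note on the chunk grammar above: 'NP: {<DT>?<JJ.*>*<NN.*>+}' matches an
# optional determiner, any number of adjectives and one or more nouns, so a
# caption like 'a small cat' yields the single phrase 'a small cat'. A rough,
# hedged sketch (the exact output depends on the nltk tagger version):
#
#   >>> find_noun_phrases('a cat and the remote')
#   ['a cat', 'the remote']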


def remove_punctuation(text: str) -> str:
    """Remove punctuation from a text.

    Args:
        text (str): The input text.

    Returns:
        str: The text with punctuation removed.
    """
    punctuation = [
        '|', ':', ';', '@', '(', ')', '[', ']', '{', '}', '^', '\'', '\"',
        '’', '`', '?', '$', '%', '#', '!', '&', '*', '+', ',', '.'
    ]
    for p in punctuation:
        text = text.replace(p, '')
    return text.strip()
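
# Hedged example: every character in the list above is dropped and surrounding
# whitespace is stripped, e.g. remove_punctuation('a remote!') -> 'a remote'.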


def run_ner(caption: str) -> Tuple[list, list]:
    """Run NER on a caption and return the tokens and noun phrases.

    Args:
        caption (str): The input caption.

    Returns:
        Tuple[List, List]: A tuple containing the tokens and noun phrases.
            - tokens_positive (List): A list of token positions.
            - noun_phrases (List): A list of noun phrases.
    """
    noun_phrases = find_noun_phrases(caption)
    noun_phrases = [remove_punctuation(phrase) for phrase in noun_phrases]
    noun_phrases = [phrase for phrase in noun_phrases if phrase != '']
    relevant_phrases = noun_phrases
    labels = noun_phrases

    tokens_positive = []
    for entity, label in zip(relevant_phrases, labels):
        try:
            # search all occurrences and mark them as different entities
            # TODO: Not Robust
            for m in re.finditer(entity, caption.lower()):
                tokens_positive.append([[m.start(), m.end()]])
        except Exception:
            print('noun entities:', noun_phrases)
            print('entity:', entity)
            print('caption:', caption.lower())
    return tokens_positive, noun_phrases
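
# Hedged sketch of run_ner's output: every noun phrase is mapped to the
# character span(s) where it occurs in the lower-cased caption. For the
# caption 'a cat sleeps.' the phrase 'a cat' covers characters [0, 5], so the
# result would look roughly like ([[[0, 5]]], ['a cat']); exact phrases depend
# on the nltk tagger.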


def create_positive_map(tokenized,
                        tokens_positive: list,
                        max_num_entities: int = 256) -> Tensor:
    """Construct a map such that positive_map[i, j] = True if box i is
    associated to token j.

    Args:
        tokenized: The tokenized input.
        tokens_positive (list): A list of token ranges
            associated with positive boxes.
        max_num_entities (int, optional): The maximum number of entities.
            Defaults to 256.

    Returns:
        torch.Tensor: The positive map.

    Raises:
        Exception: If an error occurs during token-to-char mapping.
    """
    positive_map = torch.zeros((len(tokens_positive), max_num_entities),
                               dtype=torch.float)

    for j, tok_list in enumerate(tokens_positive):
        for (beg, end) in tok_list:
            try:
                beg_pos = tokenized.char_to_token(beg)
                end_pos = tokenized.char_to_token(end - 1)
            except Exception as e:
                print('beg:', beg, 'end:', end)
                print('token_positive:', tokens_positive)
                raise e
            if beg_pos is None:
                # The boundary character may fall on whitespace or a character
                # that the tokenizer dropped, so retry on nearby characters.
                try:
                    beg_pos = tokenized.char_to_token(beg + 1)
                    if beg_pos is None:
                        beg_pos = tokenized.char_to_token(beg + 2)
                except Exception:
                    beg_pos = None
            if end_pos is None:
                try:
                    end_pos = tokenized.char_to_token(end - 2)
                    if end_pos is None:
                        end_pos = tokenized.char_to_token(end - 3)
                except Exception:
                    end_pos = None
            if beg_pos is None or end_pos is None:
                continue
            assert beg_pos is not None and end_pos is not None
            positive_map[j, beg_pos:end_pos + 1].fill_(1)

    return positive_map / (positive_map.sum(-1)[:, None] + 1e-6)
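
# Note: each row of the returned map is normalised by its sum, so row j is
# roughly a uniform distribution over the token positions covered by entity j;
# the 1e-6 only avoids division by zero for empty rows. A hedged sketch,
# assuming a HuggingFace fast tokenizer (not a dependency of this module):
#
#   >>> from transformers import AutoTokenizer
#   >>> tok = AutoTokenizer.from_pretrained('bert-base-uncased')
#   >>> t = tok(['a cat . a remote . '], return_tensors='pt')
#   >>> create_positive_map(t, [[[0, 5]], [[8, 16]]]).shape
#   torch.Size([2, 256])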


def create_positive_map_label_to_token(positive_map: Tensor,
                                       plus: int = 0) -> dict:
    """Create a dictionary mapping the label to the token.

    Args:
        positive_map (Tensor): The positive map tensor.
        plus (int, optional): Value added to the label for indexing.
            Defaults to 0.

    Returns:
        dict: The dictionary mapping the label to the token.
    """
    positive_map_label_to_token = {}
    for i in range(len(positive_map)):
        positive_map_label_to_token[i + plus] = torch.nonzero(
            positive_map[i], as_tuple=True)[0].tolist()
    return positive_map_label_to_token
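
# Hedged sketch: with plus=1 the entity indices become 1-based labels, e.g.
#
#   >>> pm = torch.tensor([[0., 1., 1., 0.], [0., 0., 0., 1.]])
#   >>> create_positive_map_label_to_token(pm, plus=1)
#   {1: [1, 2], 2: [3]}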


@MODELS.register_module()
class GLIP(SingleStageDetector):
    """Implementation of `GLIP <https://arxiv.org/abs/2112.03857>`_.

    Args:
        backbone (:obj:`ConfigDict` or dict): The backbone config.
        neck (:obj:`ConfigDict` or dict): The neck config.
        bbox_head (:obj:`ConfigDict` or dict): The bbox head config.
        language_model (:obj:`ConfigDict` or dict): The language model config.
        train_cfg (:obj:`ConfigDict` or dict, optional): The training config
            of GLIP. Defaults to None.
        test_cfg (:obj:`ConfigDict` or dict, optional): The testing config
            of GLIP. Defaults to None.
        data_preprocessor (:obj:`ConfigDict` or dict, optional): Config of
            :class:`DetDataPreprocessor` to process the input data.
            Defaults to None.
        init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
            list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 backbone: ConfigType,
                 neck: ConfigType,
                 bbox_head: ConfigType,
                 language_model: ConfigType,
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None,
                 data_preprocessor: OptConfigType = None,
                 init_cfg: OptMultiConfig = None) -> None:
        super().__init__(
            backbone=backbone,
            neck=neck,
            bbox_head=bbox_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            data_preprocessor=data_preprocessor,
            init_cfg=init_cfg)
        self.language_model = MODELS.build(language_model)

        self._text_prompts = None
        self._positive_maps = None
        self._language_dict_features = None
        self._entities = None
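
    # The four attributes above cache the most recent text prompts, their
    # positive maps, the language features and the parsed entity names so
    # that predict() can skip recomputing them when the same prompts are
    # passed again (see the `text_prompts != self._text_prompts` check in
    # predict()).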

    def get_tokens_positive_and_prompts(
            self,
            original_caption: str,
            custom_entities: bool = False) -> Tuple[dict, str]:
        """Get the tokens positive and prompts for the caption."""
        if isinstance(original_caption, (list, tuple)) or custom_entities:
            # The caption is (or is split into) an explicit list of entity
            # names that are joined with ' . ' into a single prompt.
            if custom_entities and isinstance(original_caption, str):
                if not original_caption.endswith('.'):
                    original_caption = original_caption + ' . '
                original_caption = original_caption.split(' . ')
                original_caption = list(
                    filter(lambda x: len(x) > 0, original_caption))

            caption_string = ''
            tokens_positive = []
            separation_tokens = ' . '
            for word in original_caption:
                tokens_positive.append(
                    [[len(caption_string),
                      len(caption_string) + len(word)]])
                caption_string += word
                caption_string += separation_tokens
            tokenized = self.language_model.tokenizer([caption_string],
                                                      return_tensors='pt')
            self._entities = original_caption
        else:
            # Free-form caption: extract noun phrases with nltk and use
            # their character spans as positive tokens.
            if not original_caption.endswith('.'):
                original_caption = original_caption + ' . '

            tokenized = self.language_model.tokenizer([original_caption],
                                                      return_tensors='pt')
            tokens_positive, noun_phrases = run_ner(original_caption)
            self._entities = noun_phrases
            caption_string = original_caption

        positive_map = create_positive_map(tokenized, tokens_positive)
        positive_map_label_to_token = create_positive_map_label_to_token(
            positive_map, plus=1)
        return positive_map_label_to_token, caption_string
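
    # Hedged examples for the two branches above:
    #   * entity list: original_caption='cat . remote', custom_entities=True
    #     builds the prompt 'cat . remote . ' with one positive span per
    #     entity, e.g. [[0, 3]] for 'cat';
    #   * free-form caption: 'There is a cat in the picture.' is tokenized as
    #     is and the noun phrases found by run_ner become the entities.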

    def predict(self,
                batch_inputs: Tensor,
                batch_data_samples: SampleList,
                rescale: bool = True) -> SampleList:
        """Predict results from a batch of inputs and data samples with
        post-processing.

        Args:
            batch_inputs (Tensor): Inputs with shape (N, C, H, W).
            batch_data_samples (List[:obj:`DetDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
            rescale (bool): Whether to rescale the results.
                Defaults to True.

        Returns:
            list[:obj:`DetDataSample`]: Detection results of the
            input images. Each DetDataSample usually contains
            ``pred_instances``, and ``pred_instances`` usually
            contains the following keys.

            - scores (Tensor): Classification scores, has a shape
              (num_instances, ).
            - labels (Tensor): Labels of bboxes, has a shape
              (num_instances, ).
            - label_names (List[str]): Label names of bboxes.
            - bboxes (Tensor): Has a shape (num_instances, 4),
              the last dimension 4 arranged as (x1, y1, x2, y2).
        """
        text_prompts = [
            data_samples.text for data_samples in batch_data_samples
        ]

        if 'custom_entities' in batch_data_samples[0]:
            # Assume the `custom_entities` flag is the same for every sample
            # in the batch (always true for single-image inference).
            custom_entities = batch_data_samples[0].custom_entities
        else:
            custom_entities = False

        if text_prompts != self._text_prompts:
            # avoid redundant computation
            self._text_prompts = text_prompts
            if len(set(text_prompts)) == 1:
                # All the text prompts are the same,
                # so there is no need to calculate them multiple times.
                _positive_maps_and_prompts = [
                    self.get_tokens_positive_and_prompts(
                        text_prompts[0], custom_entities)
                ] * len(batch_inputs)
            else:
                _positive_maps_and_prompts = [
                    self.get_tokens_positive_and_prompts(
                        text_prompt, custom_entities)
                    for text_prompt in text_prompts
                ]
            self._positive_maps, text_prompts = zip(
                *_positive_maps_and_prompts)
            self._language_dict_features = self.language_model(text_prompts)

        for i, data_samples in enumerate(batch_data_samples):
            data_samples.token_positive_map = self._positive_maps[i]

        visual_features = self.extract_feat(batch_inputs)

        results_list = self.bbox_head.predict(
            visual_features,
            copy.deepcopy(self._language_dict_features),
            batch_data_samples,
            rescale=rescale)

        for data_sample, pred_instances in zip(batch_data_samples,
                                               results_list):
            if len(pred_instances) > 0:
                label_names = []
                for labels in pred_instances.labels:
                    if labels >= len(self._entities):
                        warnings.warn(
                            'The unexpected output indicates an issue with '
                            'named entity recognition. You can try '
                            'setting custom_entities=True and running '
                            'again to see if it helps.')
                        label_names.append('unobject')
                    else:
                        label_names.append(self._entities[labels])
                # for visualization
                pred_instances.label_names = label_names
            data_sample.pred_instances = pred_instances
        return batch_data_samples
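

# Hedged usage sketch (kept as a comment so nothing runs on import): after a
# GLIP model is built from a config via MODELS.build, inference attaches the
# prompt to each data sample and calls predict(); the names below are
# illustrative only, not a guaranteed public API.
#
#   >>> model = MODELS.build(cfg)              # cfg: a GLIP config dict
#   >>> data_sample.text = 'bench . car'       # per-image text prompt
#   >>> data_sample.custom_entities = True     # treat the prompt as entities
#   >>> results = model.predict(batch_inputs, [data_sample])
#   >>> results[0].pred_instances.label_names  # e.g. ['bench', 'car', ...]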