refcoco.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import collections
  3. import os.path as osp
  4. import random
  5. from typing import Dict, List
  6. import mmengine
  7. from mmengine.dataset import BaseDataset
  8. from mmdet.registry import DATASETS
  9. @DATASETS.register_module()
  10. class RefCocoDataset(BaseDataset):
  11. """RefCOCO dataset.
  12. The `Refcoco` and `Refcoco+` dataset is based on
  13. `ReferItGame: Referring to Objects in Photographs of Natural Scenes
  14. <http://tamaraberg.com/papers/referit.pdf>`_.
  15. The `Refcocog` dataset is based on
  16. `Generation and Comprehension of Unambiguous Object Descriptions
  17. <https://arxiv.org/abs/1511.02283>`_.
  18. Args:
  19. ann_file (str): Annotation file path.
  20. data_root (str): The root directory for ``data_prefix`` and
  21. ``ann_file``. Defaults to ''.
  22. data_prefix (str): Prefix for training data.
  23. split_file (str): Split file path.
  24. split (str): Split name. Defaults to 'train'.
  25. text_mode (str): Text mode. Defaults to 'random'.
  26. **kwargs: Other keyword arguments in :class:`BaseDataset`.
  27. """
  28. def __init__(self,
  29. data_root: str,
  30. ann_file: str,
  31. split_file: str,
  32. data_prefix: Dict,
  33. split: str = 'train',
  34. text_mode: str = 'random',
  35. **kwargs):
  36. self.split_file = split_file
  37. self.split = split
  38. assert text_mode in ['original', 'random', 'concat', 'select_first']
  39. self.text_mode = text_mode
  40. super().__init__(
  41. data_root=data_root,
  42. data_prefix=data_prefix,
  43. ann_file=ann_file,
  44. **kwargs,
  45. )
  46. def _join_prefix(self):
  47. if not mmengine.is_abs(self.split_file) and self.split_file:
  48. self.split_file = osp.join(self.data_root, self.split_file)
  49. return super()._join_prefix()
  50. def _init_refs(self):
  51. """Initialize the refs for RefCOCO."""
  52. anns, imgs = {}, {}
  53. for ann in self.instances['annotations']:
  54. anns[ann['id']] = ann
  55. for img in self.instances['images']:
  56. imgs[img['id']] = img
  57. refs, ref_to_ann = {}, {}
  58. for ref in self.splits:
  59. # ids
  60. ref_id = ref['ref_id']
  61. ann_id = ref['ann_id']
  62. # add mapping related to ref
  63. refs[ref_id] = ref
  64. ref_to_ann[ref_id] = anns[ann_id]
  65. self.refs = refs
  66. self.ref_to_ann = ref_to_ann
  67. def load_data_list(self) -> List[dict]:
  68. """Load data list."""
  69. self.splits = mmengine.load(self.split_file, file_format='pkl')
  70. self.instances = mmengine.load(self.ann_file, file_format='json')
  71. self._init_refs()
  72. img_prefix = self.data_prefix['img_path']
  73. ref_ids = [
  74. ref['ref_id'] for ref in self.splits if ref['split'] == self.split
  75. ]
  76. full_anno = []
  77. for ref_id in ref_ids:
  78. ref = self.refs[ref_id]
  79. ann = self.ref_to_ann[ref_id]
  80. ann.update(ref)
  81. full_anno.append(ann)
  82. image_id_list = []
  83. final_anno = {}
  84. for anno in full_anno:
  85. image_id_list.append(anno['image_id'])
  86. final_anno[anno['ann_id']] = anno
  87. annotations = [value for key, value in final_anno.items()]
  88. coco_train_id = []
  89. image_annot = {}
  90. for i in range(len(self.instances['images'])):
  91. coco_train_id.append(self.instances['images'][i]['id'])
  92. image_annot[self.instances['images'][i]
  93. ['id']] = self.instances['images'][i]
  94. images = []
  95. for image_id in list(set(image_id_list)):
  96. images += [image_annot[image_id]]
  97. data_list = []
  98. grounding_dict = collections.defaultdict(list)
  99. for anno in annotations:
  100. image_id = int(anno['image_id'])
  101. grounding_dict[image_id].append(anno)
  102. join_path = mmengine.fileio.get_file_backend(img_prefix).join_path
  103. for image in images:
  104. img_id = image['id']
  105. instances = []
  106. sentences = []
  107. for grounding_anno in grounding_dict[img_id]:
  108. texts = [x['raw'].lower() for x in grounding_anno['sentences']]
  109. # random select one text
  110. if self.text_mode == 'random':
  111. idx = random.randint(0, len(texts) - 1)
  112. text = [texts[idx]]
  113. # concat all texts
  114. elif self.text_mode == 'concat':
  115. text = [''.join(texts)]
  116. # select the first text
  117. elif self.text_mode == 'select_first':
  118. text = [texts[0]]
  119. # use all texts
  120. elif self.text_mode == 'original':
  121. text = texts
  122. else:
  123. raise ValueError(f'Invalid text mode "{self.text_mode}".')
  124. ins = [{
  125. 'mask': grounding_anno['segmentation'],
  126. 'ignore_flag': 0
  127. }] * len(text)
  128. instances.extend(ins)
  129. sentences.extend(text)
  130. data_info = {
  131. 'img_path': join_path(img_prefix, image['file_name']),
  132. 'img_id': img_id,
  133. 'instances': instances,
  134. 'text': sentences
  135. }
  136. data_list.append(data_info)
  137. if len(data_list) == 0:
  138. raise ValueError(f'No sample in split "{self.split}".')
  139. return data_list