# Copyright (c) OpenMMLab. All rights reserved.
- import collections
- import os.path as osp
- import random
- from typing import Dict, List
- import mmengine
- from mmengine.dataset import BaseDataset
- from mmdet.registry import DATASETS
@DATASETS.register_module()
class RefCocoDataset(BaseDataset):
    """RefCOCO dataset.

    The `Refcoco` and `Refcoco+` dataset is based on
    `ReferItGame: Referring to Objects in Photographs of Natural Scenes
    <http://tamaraberg.com/papers/referit.pdf>`_.

    The `Refcocog` dataset is based on
    `Generation and Comprehension of Unambiguous Object Descriptions
    <https://arxiv.org/abs/1511.02283>`_.

    Args:
        data_root (str): The root directory for ``data_prefix``,
            ``ann_file`` and a relative ``split_file``.
        ann_file (str): Annotation file path (loaded as JSON).
        split_file (str): Split file path (loaded as pickle).
        data_prefix (dict): Prefix for the data. ``data_prefix['img_path']``
            is joined with the ``file_name`` of every image.
        split (str): Split name. Defaults to 'train'.
        text_mode (str): How the sentences of a referred object are turned
            into texts. One of ``'original'`` (keep every sentence),
            ``'random'`` (pick one sentence at random), ``'concat'`` (join
            all sentences into one string) or ``'select_first'`` (keep the
            first sentence). Defaults to 'random'.
        **kwargs: Other keyword arguments in :class:`BaseDataset`.
    """

    def __init__(self,
                 data_root: str,
                 ann_file: str,
                 split_file: str,
                 data_prefix: Dict,
                 split: str = 'train',
                 text_mode: str = 'random',
                 **kwargs):
        # These attributes are read by ``_join_prefix`` and
        # ``load_data_list``, so they are assigned before the parent
        # constructor runs (presumably the parent invokes those hooks --
        # confirm against ``BaseDataset.__init__``).
        self.split_file = split_file
        self.split = split
        assert text_mode in [
            'original', 'random', 'concat', 'select_first'
        ], f'Invalid text mode "{text_mode}".'
        self.text_mode = text_mode
        super().__init__(
            data_root=data_root,
            data_prefix=data_prefix,
            ann_file=ann_file,
            **kwargs,
        )

    def _join_prefix(self):
        """Resolve a relative ``split_file`` against ``data_root``.

        Returns:
            Whatever the parent ``_join_prefix`` returns.
        """
        # Test for emptiness first so that a falsy ``split_file`` is never
        # handed to ``mmengine.is_abs``.
        if self.split_file and not mmengine.is_abs(self.split_file):
            self.split_file = osp.join(self.data_root, self.split_file)
        return super()._join_prefix()

    def _init_refs(self):
        """Initialize the refs for RefCOCO.

        Builds ``self.refs`` (ref id -> ref) and ``self.ref_to_ann``
        (ref id -> annotation) from ``self.splits`` and ``self.instances``.

        Raises:
            KeyError: If a ref points to an annotation id that is absent
                from ``self.instances['annotations']``.
        """
        anns = {ann['id']: ann for ann in self.instances['annotations']}

        refs, ref_to_ann = {}, {}
        for ref in self.splits:
            ref_id = ref['ref_id']
            refs[ref_id] = ref
            ref_to_ann[ref_id] = anns[ref['ann_id']]

        self.refs = refs
        self.ref_to_ann = ref_to_ann

    def _select_texts(self, texts: List[str]) -> List[str]:
        """Reduce the sentences of one object according to ``text_mode``.

        Args:
            texts (List[str]): All sentences describing one object.

        Returns:
            List[str]: The texts to emit for that object.

        Raises:
            ValueError: If ``self.text_mode`` is not a supported mode.
        """
        if self.text_mode == 'random':
            # random select one text
            idx = random.randint(0, len(texts) - 1)
            return [texts[idx]]
        if self.text_mode == 'concat':
            # concat all texts
            # NOTE(review): sentences are joined without any separator;
            # confirm this is intended before changing it.
            return [''.join(texts)]
        if self.text_mode == 'select_first':
            # select the first text
            return [texts[0]]
        if self.text_mode == 'original':
            # use all texts
            return texts
        raise ValueError(f'Invalid text mode "{self.text_mode}".')

    def load_data_list(self) -> List[dict]:
        """Load data list.

        Loads the split file and the annotation file, keeps the refs whose
        ``split`` equals ``self.split`` and groups them per image.

        Returns:
            List[dict]: One dict per image with the keys ``img_path``,
            ``img_id``, ``instances`` (one ``{'mask', 'ignore_flag'}`` dict
            per text) and ``text`` (the texts, aligned with ``instances``).

        Raises:
            ValueError: If no sample belongs to ``self.split`` or the text
                mode is invalid.
        """
        self.splits = mmengine.load(self.split_file, file_format='pkl')
        self.instances = mmengine.load(self.ann_file, file_format='json')
        self._init_refs()
        img_prefix = self.data_prefix['img_path']

        # Merge every selected ref into (a copy of) its annotation, keyed by
        # annotation id. Refs sharing an annotation are applied in order, so
        # later refs override earlier ones. Working on a copy keeps the
        # loaded ``self.instances`` annotations untouched.
        merged_annos = {}
        for split_ref in self.splits:
            if split_ref['split'] != self.split:
                continue
            ref_id = split_ref['ref_id']
            ref = self.refs[ref_id]
            ann_id = ref['ann_id']
            if ann_id not in merged_annos:
                merged_annos[ann_id] = dict(self.ref_to_ann[ref_id])
            merged_annos[ann_id].update(ref)
        annotations = list(merged_annos.values())

        image_annot = {img['id']: img for img in self.instances['images']}
        # Deduplicate the referenced images through a set, as before, so the
        # resulting sample order is unchanged.
        image_ids = set(anno['image_id'] for anno in annotations)
        images = [image_annot[image_id] for image_id in image_ids]

        grounding_dict = collections.defaultdict(list)
        for anno in annotations:
            grounding_dict[int(anno['image_id'])].append(anno)

        join_path = mmengine.fileio.get_file_backend(img_prefix).join_path
        data_list = []
        for image in images:
            img_id = image['id']
            instances = []
            sentences = []
            for grounding_anno in grounding_dict[img_id]:
                texts = [
                    x['raw'].lower() for x in grounding_anno['sentences']
                ]
                text = self._select_texts(texts)
                # Build a separate dict for every text: ``[d] * n`` would
                # repeat the same dict object, so an in-place change to one
                # instance would show up in all of them.
                instances.extend({
                    'mask': grounding_anno['segmentation'],
                    'ignore_flag': 0
                } for _ in text)
                sentences.extend(text)
            data_list.append({
                'img_path': join_path(img_prefix, image['file_name']),
                'img_id': img_id,
                'instances': instances,
                'text': sentences
            })

        if not data_list:
            raise ValueError(f'No sample in split "{self.split}".')
        return data_list
|