dsdl.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192
  1. # Copyright (c) OpenMMLab. All rights reserved.
import os
from typing import List, Optional

from mmdet.registry import DATASETS
from .base_det_dataset import BaseDetDataset

try:
    from dsdl.dataset import DSDLDataset
except ImportError:
    DSDLDataset = None
  10. @DATASETS.register_module()
  11. class DSDLDetDataset(BaseDetDataset):
  12. """Dataset for dsdl detection.
  13. Args:
  14. with_bbox(bool): Load bbox or not, defaults to be True.
  15. with_polygon(bool): Load polygon or not, defaults to be False.
  16. with_mask(bool): Load seg map mask or not, defaults to be False.
  17. with_imagelevel_label(bool): Load image level label or not,
  18. defaults to be False.
  19. with_hierarchy(bool): Load hierarchy information or not,
  20. defaults to be False.
  21. specific_key_path(dict): Path of specific key which can not
  22. be loaded by it's field name.
  23. pre_transform(dict): pre-transform functions before loading.
  24. """
  25. METAINFO = {}
  26. def __init__(self,
  27. with_bbox: bool = True,
  28. with_polygon: bool = False,
  29. with_mask: bool = False,
  30. with_imagelevel_label: bool = False,
  31. with_hierarchy: bool = False,
  32. specific_key_path: dict = {},
  33. pre_transform: dict = {},
  34. **kwargs) -> None:
  35. if DSDLDataset is None:
  36. raise RuntimeError(
  37. 'Package dsdl is not installed. Please run "pip install dsdl".'
  38. )
  39. self.with_hierarchy = with_hierarchy
  40. self.specific_key_path = specific_key_path
  41. loc_config = dict(type='LocalFileReader', working_dir='')
  42. if kwargs.get('data_root'):
  43. kwargs['ann_file'] = os.path.join(kwargs['data_root'],
  44. kwargs['ann_file'])
  45. self.required_fields = ['Image', 'ImageShape', 'Label', 'ignore_flag']
  46. if with_bbox:
  47. self.required_fields.append('Bbox')
  48. if with_polygon:
  49. self.required_fields.append('Polygon')
  50. if with_mask:
  51. self.required_fields.append('LabelMap')
  52. if with_imagelevel_label:
  53. self.required_fields.append('image_level_labels')
  54. assert 'image_level_labels' in specific_key_path.keys(
  55. ), '`image_level_labels` not specified in `specific_key_path` !'
  56. self.extra_keys = [
  57. key for key in self.specific_key_path.keys()
  58. if key not in self.required_fields
  59. ]
  60. self.dsdldataset = DSDLDataset(
  61. dsdl_yaml=kwargs['ann_file'],
  62. location_config=loc_config,
  63. required_fields=self.required_fields,
  64. specific_key_path=specific_key_path,
  65. transform=pre_transform,
  66. )
  67. BaseDetDataset.__init__(self, **kwargs)
  68. def load_data_list(self) -> List[dict]:
  69. """Load data info from an dsdl yaml file named as ``self.ann_file``
  70. Returns:
  71. List[dict]: A list of data info.
  72. """
  73. if self.with_hierarchy:
  74. # get classes_names and relation_matrix
  75. classes_names, relation_matrix = \
  76. self.dsdldataset.class_dom.get_hierarchy_info()
  77. self._metainfo['classes'] = tuple(classes_names)
  78. self._metainfo['RELATION_MATRIX'] = relation_matrix
  79. else:
  80. self._metainfo['classes'] = tuple(self.dsdldataset.class_names)
  81. data_list = []
  82. for i, data in enumerate(self.dsdldataset):
  83. # basic image info, including image id, path and size.
  84. datainfo = dict(
  85. img_id=i,
  86. img_path=os.path.join(self.data_prefix['img_path'],
  87. data['Image'][0].location),
  88. width=data['ImageShape'][0].width,
  89. height=data['ImageShape'][0].height,
  90. )
  91. # get image label info
  92. if 'image_level_labels' in data.keys():
  93. if self.with_hierarchy:
  94. # get leaf node name when using hierarchy classes
  95. datainfo['image_level_labels'] = [
  96. self._metainfo['classes'].index(i.leaf_node_name)
  97. for i in data['image_level_labels']
  98. ]
  99. else:
  100. datainfo['image_level_labels'] = [
  101. self._metainfo['classes'].index(i.name)
  102. for i in data['image_level_labels']
  103. ]
  104. # get semantic segmentation info
  105. if 'LabelMap' in data.keys():
  106. datainfo['seg_map_path'] = data['LabelMap']
  107. # load instance info
  108. instances = []
  109. if 'Bbox' in data.keys():
  110. for idx in range(len(data['Bbox'])):
  111. bbox = data['Bbox'][idx]
  112. if self.with_hierarchy:
  113. # get leaf node name when using hierarchy classes
  114. label = data['Label'][idx].leaf_node_name
  115. label_index = self._metainfo['classes'].index(label)
  116. else:
  117. label = data['Label'][idx].name
  118. label_index = self._metainfo['classes'].index(label)
  119. instance = {}
  120. instance['bbox'] = bbox.xyxy
  121. instance['bbox_label'] = label_index
  122. if 'ignore_flag' in data.keys():
  123. # get ignore flag
  124. instance['ignore_flag'] = data['ignore_flag'][idx]
  125. else:
  126. instance['ignore_flag'] = 0
  127. if 'Polygon' in data.keys():
  128. # get polygon info
  129. polygon = data['Polygon'][idx]
  130. instance['mask'] = polygon.openmmlabformat
  131. for key in self.extra_keys:
  132. # load extra instance info
  133. instance[key] = data[key][idx]
  134. instances.append(instance)
  135. datainfo['instances'] = instances
  136. # append a standard sample in data list
  137. if len(datainfo['instances']) > 0:
  138. data_list.append(datainfo)
  139. return data_list
  140. def filter_data(self) -> List[dict]:
  141. """Filter annotations according to filter_cfg.
  142. Returns:
  143. List[dict]: Filtered results.
  144. """
  145. if self.test_mode:
  146. return self.data_list
  147. filter_empty_gt = self.filter_cfg.get('filter_empty_gt', False) \
  148. if self.filter_cfg is not None else False
  149. min_size = self.filter_cfg.get('min_size', 0) \
  150. if self.filter_cfg is not None else 0
  151. valid_data_list = []
  152. for i, data_info in enumerate(self.data_list):
  153. width = data_info['width']
  154. height = data_info['height']
  155. if filter_empty_gt and len(data_info['instances']) == 0:
  156. continue
  157. if min(width, height) >= min_size:
  158. valid_data_list.append(data_info)
  159. return valid_data_list