labelme2coco.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. #!/usr/bin/env python
  2. import argparse
  3. import collections
  4. import datetime
  5. import glob
  6. import json
  7. import os
  8. import os.path as osp
  9. import sys
  10. import uuid
  11. import numpy as np
  12. import PIL.Image
  13. import labelme
  14. try:
  15. import pycocotools.mask
  16. except ImportError:
  17. print('Please install pycocotools:\n\n pip install pycocotools\n')
  18. sys.exit(1)
  19. def main():
  20. parser = argparse.ArgumentParser(
  21. formatter_class=argparse.ArgumentDefaultsHelpFormatter
  22. )
  23. parser.add_argument('input_dir', help='input annotated directory')
  24. parser.add_argument('output_dir', help='output dataset directory')
  25. parser.add_argument('--labels', help='labels file', required=True)
  26. args = parser.parse_args()
  27. if osp.exists(args.output_dir):
  28. print('Output directory already exists:', args.output_dir)
  29. sys.exit(1)
  30. os.makedirs(args.output_dir)
  31. os.makedirs(osp.join(args.output_dir, 'JPEGImages'))
  32. print('Creating dataset:', args.output_dir)
  33. now = datetime.datetime.now()
  34. data = dict(
  35. info=dict(
  36. description=None,
  37. url=None,
  38. version=None,
  39. year=now.year,
  40. contributor=None,
  41. date_created=now.strftime('%Y-%m-%d %H:%M:%S.%f'),
  42. ),
  43. licenses=[dict(
  44. url=None,
  45. id=0,
  46. name=None,
  47. )],
  48. images=[
  49. # license, url, file_name, height, width, date_captured, id
  50. ],
  51. type='instances',
  52. annotations=[
  53. # segmentation, area, iscrowd, image_id, bbox, category_id, id
  54. ],
  55. categories=[
  56. # supercategory, id, name
  57. ],
  58. )
  59. class_name_to_id = {}
  60. for i, line in enumerate(open(args.labels).readlines()):
  61. class_id = i - 1 # starts with -1
  62. class_name = line.strip()
  63. if class_id == -1:
  64. assert class_name == '__ignore__'
  65. continue
  66. class_name_to_id[class_name] = class_id
  67. data['categories'].append(dict(
  68. supercategory=None,
  69. id=class_id,
  70. name=class_name,
  71. ))
  72. out_ann_file = osp.join(args.output_dir, 'annotations.json')
  73. label_files = glob.glob(osp.join(args.input_dir, '*.json'))
  74. for image_id, label_file in enumerate(label_files):
  75. print('Generating dataset from:', label_file)
  76. with open(label_file) as f:
  77. label_data = json.load(f)
  78. base = osp.splitext(osp.basename(label_file))[0]
  79. out_img_file = osp.join(
  80. args.output_dir, 'JPEGImages', base + '.jpg'
  81. )
  82. img_file = osp.join(
  83. osp.dirname(label_file), label_data['imagePath']
  84. )
  85. img = np.asarray(PIL.Image.open(img_file))
  86. PIL.Image.fromarray(img).save(out_img_file)
  87. data['images'].append(dict(
  88. license=0,
  89. url=None,
  90. file_name=osp.relpath(out_img_file, osp.dirname(out_ann_file)),
  91. height=img.shape[0],
  92. width=img.shape[1],
  93. date_captured=None,
  94. id=image_id,
  95. ))
  96. masks = {} # for area
  97. segmentations = collections.defaultdict(list) # for segmentation
  98. for shape in label_data['shapes']:
  99. points = shape['points']
  100. label = shape['label']
  101. group_id = shape.get('group_id')
  102. shape_type = shape.get('shape_type')
  103. mask = labelme.utils.shape_to_mask(
  104. img.shape[:2], points, shape_type
  105. )
  106. if group_id is None:
  107. group_id = uuid.uuid1()
  108. instance = (label, group_id)
  109. if instance in masks:
  110. masks[instance] = masks[instance] | mask
  111. else:
  112. masks[instance] = mask
  113. points = np.asarray(points).flatten().tolist()
  114. segmentations[instance].append(points)
  115. segmentations = dict(segmentations)
  116. for instance, mask in masks.items():
  117. cls_name, group_id = instance
  118. if cls_name not in class_name_to_id:
  119. continue
  120. cls_id = class_name_to_id[cls_name]
  121. mask = np.asfortranarray(mask.astype(np.uint8))
  122. mask = pycocotools.mask.encode(mask)
  123. area = float(pycocotools.mask.area(mask))
  124. bbox = pycocotools.mask.toBbox(mask).flatten().tolist()
  125. data['annotations'].append(dict(
  126. id=len(data['annotations']),
  127. image_id=image_id,
  128. category_id=cls_id,
  129. segmentation=segmentations[instance],
  130. area=area,
  131. bbox=bbox,
  132. iscrowd=0,
  133. ))
  134. with open(out_ann_file, 'w') as f:
  135. json.dump(data, f)
  136. if __name__ == '__main__':
  137. main()