# labelme2coco.py 4.5 KB

  1. #!/usr/bin/env python
  2. import argparse
  3. import collections
  4. import datetime
  5. import glob
  6. import json
  7. import os
  8. import os.path as osp
  9. import sys
  10. import numpy as np
  11. import PIL.Image
  12. import labelme
  13. try:
  14. import pycocotools.mask
  15. except ImportError:
  16. print('Please install pycocotools:\n\n pip install pycocotools\n')
  17. sys.exit(1)
  18. def main():
  19. parser = argparse.ArgumentParser(
  20. formatter_class=argparse.ArgumentDefaultsHelpFormatter
  21. )
  22. parser.add_argument('input_dir', help='input annotated directory')
  23. parser.add_argument('output_dir', help='output dataset directory')
  24. parser.add_argument('--labels', help='labels file', required=True)
  25. args = parser.parse_args()
  26. if osp.exists(args.output_dir):
  27. print('Output directory already exists:', args.output_dir)
  28. sys.exit(1)
  29. os.makedirs(args.output_dir)
  30. os.makedirs(osp.join(args.output_dir, 'JPEGImages'))
  31. print('Creating dataset:', args.output_dir)
  32. now = datetime.datetime.now()
  33. data = dict(
  34. info=dict(
  35. description=None,
  36. url=None,
  37. version=None,
  38. year=now.year,
  39. contributor=None,
  40. date_created=now.strftime('%Y-%m-%d %H:%M:%S.%f'),
  41. ),
  42. licenses=[dict(
  43. url=None,
  44. id=0,
  45. name=None,
  46. )],
  47. images=[
  48. # license, url, file_name, height, width, date_captured, id
  49. ],
  50. type='instances',
  51. annotations=[
  52. # segmentation, area, iscrowd, image_id, bbox, category_id, id
  53. ],
  54. categories=[
  55. # supercategory, id, name
  56. ],
  57. )
  58. class_name_to_id = {}
  59. for i, line in enumerate(open(args.labels).readlines()):
  60. class_id = i - 1 # starts with -1
  61. class_name = line.strip()
  62. if class_id == -1:
  63. assert class_name == '__ignore__'
  64. continue
  65. class_name_to_id[class_name] = class_id
  66. data['categories'].append(dict(
  67. supercategory=None,
  68. id=class_id,
  69. name=class_name,
  70. ))
  71. out_ann_file = osp.join(args.output_dir, 'annotations.json')
  72. label_files = glob.glob(osp.join(args.input_dir, '*.json'))
  73. for image_id, label_file in enumerate(label_files):
  74. print('Generating dataset from:', label_file)
  75. with open(label_file) as f:
  76. label_data = json.load(f)
  77. base = osp.splitext(osp.basename(label_file))[0]
  78. out_img_file = osp.join(
  79. args.output_dir, 'JPEGImages', base + '.jpg'
  80. )
  81. img_file = osp.join(
  82. osp.dirname(label_file), label_data['imagePath']
  83. )
  84. img = np.asarray(PIL.Image.open(img_file))
  85. PIL.Image.fromarray(img).save(out_img_file)
  86. data['images'].append(dict(
  87. license=0,
  88. url=None,
  89. file_name=osp.relpath(out_img_file, osp.dirname(out_ann_file)),
  90. height=img.shape[0],
  91. width=img.shape[1],
  92. date_captured=None,
  93. id=image_id,
  94. ))
  95. masks = {} # for area
  96. segmentations = collections.defaultdict(list) # for segmentation
  97. for shape in label_data['shapes']:
  98. points = shape['points']
  99. label = shape['label']
  100. shape_type = shape.get('shape_type', None)
  101. mask = labelme.utils.shape_to_mask(
  102. img.shape[:2], points, shape_type
  103. )
  104. if label in masks:
  105. masks[label] = masks[label] | mask
  106. else:
  107. masks[label] = mask
  108. points = np.asarray(points).flatten().tolist()
  109. segmentations[label].append(points)
  110. for label, mask in masks.items():
  111. cls_name = label.split('-')[0]
  112. if cls_name not in class_name_to_id:
  113. continue
  114. cls_id = class_name_to_id[cls_name]
  115. mask = np.asfortranarray(mask.astype(np.uint8))
  116. mask = pycocotools.mask.encode(mask)
  117. area = float(pycocotools.mask.area(mask))
  118. bbox = pycocotools.mask.toBbox(mask).flatten().tolist()
  119. data['annotations'].append(dict(
  120. id=len(data['annotations']),
  121. image_id=image_id,
  122. category_id=cls_id,
  123. segmentation=segmentations[label],
  124. area=area,
  125. bbox=bbox,
  126. iscrowd=0,
  127. ))
  128. with open(out_ann_file, 'w') as f:
  129. json.dump(data, f)
  130. if __name__ == '__main__':
  131. main()