labelme2coco.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209
  1. #!/usr/bin/env python
  2. import argparse
  3. import collections
  4. import datetime
  5. import glob
  6. import json
  7. import os
  8. import os.path as osp
  9. import sys
  10. import uuid
  11. import imgviz
  12. import numpy as np
  13. import labelme
  14. try:
  15. import pycocotools.mask
  16. except ImportError:
  17. print("Please install pycocotools:\n\n pip install pycocotools\n")
  18. sys.exit(1)
  19. def main():
  20. parser = argparse.ArgumentParser(
  21. formatter_class=argparse.ArgumentDefaultsHelpFormatter
  22. )
  23. parser.add_argument("input_dir", help="input annotated directory")
  24. parser.add_argument("output_dir", help="output dataset directory")
  25. parser.add_argument("--labels", help="labels file", required=True)
  26. parser.add_argument(
  27. "--noviz", help="no visualization", action="store_true"
  28. )
  29. args = parser.parse_args()
  30. if osp.exists(args.output_dir):
  31. print("Output directory already exists:", args.output_dir)
  32. sys.exit(1)
  33. os.makedirs(args.output_dir)
  34. os.makedirs(osp.join(args.output_dir, "JPEGImages"))
  35. if not args.noviz:
  36. os.makedirs(osp.join(args.output_dir, "Visualization"))
  37. print("Creating dataset:", args.output_dir)
  38. now = datetime.datetime.now()
  39. data = dict(
  40. info=dict(
  41. description=None,
  42. url=None,
  43. version=None,
  44. year=now.year,
  45. contributor=None,
  46. date_created=now.strftime("%Y-%m-%d %H:%M:%S.%f"),
  47. ),
  48. licenses=[
  49. dict(
  50. url=None,
  51. id=0,
  52. name=None,
  53. )
  54. ],
  55. images=[
  56. # license, url, file_name, height, width, date_captured, id
  57. ],
  58. type="instances",
  59. annotations=[
  60. # segmentation, area, iscrowd, image_id, bbox, category_id, id
  61. ],
  62. categories=[
  63. # supercategory, id, name
  64. ],
  65. )
  66. class_name_to_id = {}
  67. for i, line in enumerate(open(args.labels).readlines()):
  68. class_id = i - 1 # starts with -1
  69. class_name = line.strip()
  70. if class_id == -1:
  71. assert class_name == "__ignore__"
  72. continue
  73. class_name_to_id[class_name] = class_id
  74. data["categories"].append(
  75. dict(
  76. supercategory=None,
  77. id=class_id,
  78. name=class_name,
  79. )
  80. )
  81. out_ann_file = osp.join(args.output_dir, "annotations.json")
  82. label_files = glob.glob(osp.join(args.input_dir, "*.json"))
  83. for image_id, filename in enumerate(label_files):
  84. print("Generating dataset from:", filename)
  85. label_file = labelme.LabelFile(filename=filename)
  86. base = osp.splitext(osp.basename(filename))[0]
  87. out_img_file = osp.join(args.output_dir, "JPEGImages", base + ".jpg")
  88. img = labelme.utils.img_data_to_arr(label_file.imageData)
  89. imgviz.io.imsave(out_img_file, img)
  90. data["images"].append(
  91. dict(
  92. license=0,
  93. url=None,
  94. file_name=osp.relpath(out_img_file, osp.dirname(out_ann_file)),
  95. height=img.shape[0],
  96. width=img.shape[1],
  97. date_captured=None,
  98. id=image_id,
  99. )
  100. )
  101. masks = {} # for area
  102. segmentations = collections.defaultdict(list) # for segmentation
  103. for shape in label_file.shapes:
  104. points = shape["points"]
  105. label = shape["label"]
  106. group_id = shape.get("group_id")
  107. shape_type = shape.get("shape_type", "polygon")
  108. mask = labelme.utils.shape_to_mask(
  109. img.shape[:2], points, shape_type
  110. )
  111. if group_id is None:
  112. group_id = uuid.uuid1()
  113. instance = (label, group_id)
  114. if instance in masks:
  115. masks[instance] = masks[instance] | mask
  116. else:
  117. masks[instance] = mask
  118. if shape_type == "rectangle":
  119. (x1, y1), (x2, y2) = points
  120. x1, x2 = sorted([x1, x2])
  121. y1, y2 = sorted([y1, y2])
  122. points = [x1, y1, x2, y1, x2, y2, x1, y2]
  123. if shape_type == "circle":
  124. (x1, y1), (x2, y2) = points
  125. r = np.linalg.norm([x2 - x1, y2 - y1])
  126. # r(1-cos(a/2))<x, a=2*pi/N => N>pi/arccos(1-x/r)
  127. # x: tolerance of the gap between the arc and the line segment
  128. n_points_circle = max(int(np.pi / np.arccos(1 - 1 / r)), 12)
  129. i = np.arange(n_points_circle)
  130. x = x1 + r * np.sin(2 * np.pi / n_points_circle * i)
  131. y = y1 + r * np.cos(2 * np.pi / n_points_circle * i)
  132. points = np.stack((x, y), axis=1).flatten().tolist()
  133. else:
  134. points = np.asarray(points).flatten().tolist()
  135. segmentations[instance].append(points)
  136. segmentations = dict(segmentations)
  137. for instance, mask in masks.items():
  138. cls_name, group_id = instance
  139. if cls_name not in class_name_to_id:
  140. continue
  141. cls_id = class_name_to_id[cls_name]
  142. mask = np.asfortranarray(mask.astype(np.uint8))
  143. mask = pycocotools.mask.encode(mask)
  144. area = float(pycocotools.mask.area(mask))
  145. bbox = pycocotools.mask.toBbox(mask).flatten().tolist()
  146. data["annotations"].append(
  147. dict(
  148. id=len(data["annotations"]),
  149. image_id=image_id,
  150. category_id=cls_id,
  151. segmentation=segmentations[instance],
  152. area=area,
  153. bbox=bbox,
  154. iscrowd=0,
  155. )
  156. )
  157. if not args.noviz:
  158. viz = img
  159. if masks:
  160. labels, captions, masks = zip(
  161. *[
  162. (class_name_to_id[cnm], cnm, msk)
  163. for (cnm, gid), msk in masks.items()
  164. if cnm in class_name_to_id
  165. ]
  166. )
  167. viz = imgviz.instances2rgb(
  168. image=img,
  169. labels=labels,
  170. masks=masks,
  171. captions=captions,
  172. font_size=15,
  173. line_width=2,
  174. )
  175. out_viz_file = osp.join(
  176. args.output_dir, "Visualization", base + ".jpg"
  177. )
  178. imgviz.io.imsave(out_viz_file, viz)
  179. with open(out_ann_file, "w") as f:
  180. json.dump(data, f)
  181. if __name__ == "__main__":
  182. main()