# labelme2voc.py (4.7 KB)
  1. #!/usr/bin/env python
  2. from __future__ import print_function
  3. import argparse
  4. import glob
  5. import os
  6. import os.path as osp
  7. import sys
  8. import imgviz
  9. import labelme
  10. try:
  11. import lxml.builder
  12. import lxml.etree
  13. except ImportError:
  14. print("Please install lxml:\n\n pip install lxml\n")
  15. sys.exit(1)
  16. def main():
  17. parser = argparse.ArgumentParser(
  18. formatter_class=argparse.ArgumentDefaultsHelpFormatter
  19. )
  20. parser.add_argument("input_dir", help="input annotated directory")
  21. parser.add_argument("output_dir", help="output dataset directory")
  22. parser.add_argument("--labels", help="labels file", required=True)
  23. parser.add_argument(
  24. "--noviz", help="no visualization", action="store_true"
  25. )
  26. args = parser.parse_args()
  27. if osp.exists(args.output_dir):
  28. print("Output directory already exists:", args.output_dir)
  29. sys.exit(1)
  30. os.makedirs(args.output_dir)
  31. os.makedirs(osp.join(args.output_dir, "JPEGImages"))
  32. os.makedirs(osp.join(args.output_dir, "Annotations"))
  33. if not args.noviz:
  34. os.makedirs(osp.join(args.output_dir, "AnnotationsVisualization"))
  35. print("Creating dataset:", args.output_dir)
  36. class_names = []
  37. class_name_to_id = {}
  38. for i, line in enumerate(open(args.labels).readlines()):
  39. class_id = i - 1 # starts with -1
  40. class_name = line.strip()
  41. class_name_to_id[class_name] = class_id
  42. if class_id == -1:
  43. assert class_name == "__ignore__"
  44. continue
  45. elif class_id == 0:
  46. assert class_name == "_background_"
  47. class_names.append(class_name)
  48. class_names = tuple(class_names)
  49. print("class_names:", class_names)
  50. out_class_names_file = osp.join(args.output_dir, "class_names.txt")
  51. with open(out_class_names_file, "w") as f:
  52. f.writelines("\n".join(class_names))
  53. print("Saved class_names:", out_class_names_file)
  54. for filename in glob.glob(osp.join(args.input_dir, "*.json")):
  55. print("Generating dataset from:", filename)
  56. label_file = labelme.LabelFile(filename=filename)
  57. base = osp.splitext(osp.basename(filename))[0]
  58. out_img_file = osp.join(args.output_dir, "JPEGImages", base + ".jpg")
  59. out_xml_file = osp.join(args.output_dir, "Annotations", base + ".xml")
  60. if not args.noviz:
  61. out_viz_file = osp.join(
  62. args.output_dir, "AnnotationsVisualization", base + ".jpg"
  63. )
  64. img = labelme.utils.img_data_to_arr(label_file.imageData)
  65. imgviz.io.imsave(out_img_file, img)
  66. maker = lxml.builder.ElementMaker()
  67. xml = maker.annotation(
  68. maker.folder(),
  69. maker.filename(base + ".jpg"),
  70. maker.database(), # e.g., The VOC2007 Database
  71. maker.annotation(), # e.g., Pascal VOC2007
  72. maker.image(), # e.g., flickr
  73. maker.size(
  74. maker.height(str(img.shape[0])),
  75. maker.width(str(img.shape[1])),
  76. maker.depth(str(img.shape[2])),
  77. ),
  78. maker.segmented(),
  79. )
  80. bboxes = []
  81. labels = []
  82. for shape in label_file.shapes:
  83. if shape["shape_type"] != "rectangle":
  84. print(
  85. "Skipping shape: label={label}, "
  86. "shape_type={shape_type}".format(**shape)
  87. )
  88. continue
  89. class_name = shape["label"]
  90. class_id = class_names.index(class_name)
  91. (xmin, ymin), (xmax, ymax) = shape["points"]
  92. # swap if min is larger than max.
  93. xmin, xmax = sorted([xmin, xmax])
  94. ymin, ymax = sorted([ymin, ymax])
  95. bboxes.append([ymin, xmin, ymax, xmax])
  96. labels.append(class_id)
  97. xml.append(
  98. maker.object(
  99. maker.name(shape["label"]),
  100. maker.pose(),
  101. maker.truncated(),
  102. maker.difficult(),
  103. maker.bndbox(
  104. maker.xmin(str(xmin)),
  105. maker.ymin(str(ymin)),
  106. maker.xmax(str(xmax)),
  107. maker.ymax(str(ymax)),
  108. ),
  109. )
  110. )
  111. if not args.noviz:
  112. captions = [class_names[label] for label in labels]
  113. viz = imgviz.instances2rgb(
  114. image=img,
  115. labels=labels,
  116. bboxes=bboxes,
  117. captions=captions,
  118. font_size=15,
  119. )
  120. imgviz.io.imsave(out_viz_file, viz)
  121. with open(out_xml_file, "wb") as f:
  122. f.write(lxml.etree.tostring(xml, pretty_print=True))
  123. if __name__ == "__main__":
  124. main()