labelme2voc.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
  1. #!/usr/bin/env python
  2. from __future__ import print_function
  3. import argparse
  4. import glob
  5. import json
  6. import os
  7. import os.path as osp
  8. import lxml.builder
  9. import lxml.etree
  10. import numpy as np
  11. import PIL.Image
  12. import labelme
  13. def main():
  14. parser = argparse.ArgumentParser(
  15. formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  16. parser.add_argument('labels_file')
  17. parser.add_argument('in_dir', help='input dir with annotated files')
  18. parser.add_argument('out_dir', help='output dataset directory')
  19. args = parser.parse_args()
  20. if osp.exists(args.out_dir):
  21. print('Output directory already exists:', args.out_dir)
  22. quit(1)
  23. os.makedirs(args.out_dir)
  24. os.makedirs(osp.join(args.out_dir, 'JPEGImages'))
  25. os.makedirs(osp.join(args.out_dir, 'Annotations'))
  26. os.makedirs(osp.join(args.out_dir, 'AnnotationsVisualization'))
  27. print('Creating dataset:', args.out_dir)
  28. class_names = []
  29. class_name_to_id = {}
  30. for i, line in enumerate(open(args.labels_file).readlines()):
  31. class_id = i - 1 # starts with -1
  32. class_name = line.strip()
  33. class_name_to_id[class_name] = class_id
  34. if class_id == -1:
  35. assert class_name == '__ignore__'
  36. continue
  37. elif class_id == 0:
  38. assert class_name == '_background_'
  39. class_names.append(class_name)
  40. class_names = tuple(class_names)
  41. print('class_names:', class_names)
  42. out_class_names_file = osp.join(args.out_dir, 'class_names.txt')
  43. with open(out_class_names_file, 'w') as f:
  44. f.writelines('\n'.join(class_names))
  45. print('Saved class_names:', out_class_names_file)
  46. for label_file in glob.glob(osp.join(args.in_dir, '*.json')):
  47. print('Generating dataset from:', label_file)
  48. with open(label_file) as f:
  49. data = json.load(f)
  50. base = osp.splitext(osp.basename(label_file))[0]
  51. out_img_file = osp.join(
  52. args.out_dir, 'JPEGImages', base + '.jpg')
  53. out_xml_file = osp.join(
  54. args.out_dir, 'Annotations', base + '.xml')
  55. out_viz_file = osp.join(
  56. args.out_dir, 'AnnotationsVisualization', base + '.jpg')
  57. img_file = osp.join(osp.dirname(label_file), data['imagePath'])
  58. img = np.asarray(PIL.Image.open(img_file))
  59. PIL.Image.fromarray(img).save(out_img_file)
  60. maker = lxml.builder.ElementMaker()
  61. xml = maker.annotation(
  62. maker.folder(),
  63. maker.filename(base + '.jpg'),
  64. maker.database(), # e.g., The VOC2007 Database
  65. maker.annotation(), # e.g., Pascal VOC2007
  66. maker.image(), # e.g., flickr
  67. maker.size(
  68. maker.height(str(img.shape[0])),
  69. maker.width(str(img.shape[1])),
  70. maker.depth(str(img.shape[2])),
  71. ),
  72. maker.segmented(),
  73. )
  74. bboxes = []
  75. labels = []
  76. for shape in data['shapes']:
  77. if shape['shape_type'] != 'rectangle':
  78. print('Skipping shape: label={label}, shape_type={shape_type}'
  79. .format(**shape))
  80. continue
  81. class_name = shape['label']
  82. class_id = class_names.index(class_name)
  83. (xmin, ymin), (xmax, ymax) = shape['points']
  84. bboxes.append([xmin, ymin, xmax, ymax])
  85. labels.append(class_id)
  86. xml.append(
  87. maker.object(
  88. maker.name(shape['label']),
  89. maker.pose(),
  90. maker.truncated(),
  91. maker.difficult(),
  92. maker.bndbox(
  93. maker.xmin(str(xmin)),
  94. maker.ymin(str(ymin)),
  95. maker.xmax(str(xmax)),
  96. maker.ymax(str(ymax)),
  97. ),
  98. )
  99. )
  100. captions = [class_names[l] for l in labels]
  101. viz = labelme.utils.draw_instances(
  102. img, bboxes, labels, captions=captions
  103. )
  104. PIL.Image.fromarray(viz).save(out_viz_file)
  105. with open(out_xml_file, 'wb') as f:
  106. f.write(lxml.etree.tostring(xml, pretty_print=True))
  107. if __name__ == '__main__':
  108. main()