labelme2voc.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162
  1. #!/usr/bin/env python
  2. from __future__ import print_function
  3. import argparse
  4. import glob
  5. import io
  6. import json
  7. import os
  8. import os.path as osp
  9. import matplotlib
  10. matplotlib.use('Agg')
  11. import matplotlib.pyplot as plt
  12. import numpy as np
  13. import PIL.Image
  14. import PIL.ImagePalette
  15. import skimage.color
  16. import skimage.io
  17. import labelme
  18. from labelme.utils import label2rgb
  19. from labelme.utils import label_colormap
  20. # TODO(wkentaro): Move to labelme/utils.py
  21. # contrib
  22. # -----------------------------------------------------------------------------
  23. def labelme_shapes_to_label(img_shape, shapes, label_name_to_value):
  24. lbl = np.zeros(img_shape[:2], dtype=np.int32)
  25. for shape in shapes:
  26. polygons = shape['points']
  27. label_name = shape['label']
  28. if label_name in label_name_to_value:
  29. label_value = label_name_to_value[label_name]
  30. else:
  31. label_value = len(label_name_to_value)
  32. label_name_to_value[label_name] = label_value
  33. mask = labelme.utils.polygons_to_mask(img_shape[:2], polygons)
  34. lbl[mask] = label_value
  35. return lbl
  36. def draw_label(label, img, label_names, colormap=None):
  37. plt.subplots_adjust(left=0, right=1, top=1, bottom=0,
  38. wspace=0, hspace=0)
  39. plt.margins(0, 0)
  40. plt.gca().xaxis.set_major_locator(plt.NullLocator())
  41. plt.gca().yaxis.set_major_locator(plt.NullLocator())
  42. if colormap is None:
  43. colormap = label_colormap(len(label_names))
  44. label_viz = label2rgb(
  45. label, img, n_labels=len(label_names), alpha=.5)
  46. plt.imshow(label_viz)
  47. plt.axis('off')
  48. plt_handlers = []
  49. plt_titles = []
  50. for label_value, label_name in enumerate(label_names):
  51. if label_value not in label:
  52. continue
  53. if label_name.startswith('_'):
  54. continue
  55. fc = colormap[label_value]
  56. p = plt.Rectangle((0, 0), 1, 1, fc=fc)
  57. plt_handlers.append(p)
  58. plt_titles.append(label_name)
  59. plt.legend(plt_handlers, plt_titles, loc='lower right', framealpha=.5)
  60. f = io.BytesIO()
  61. plt.savefig(f, bbox_inches='tight', pad_inches=0)
  62. plt.cla()
  63. plt.close()
  64. out_size = (img.shape[1], img.shape[0])
  65. out = PIL.Image.open(f).resize(out_size, PIL.Image.BILINEAR).convert('RGB')
  66. out = np.asarray(out)
  67. return out
  68. # -----------------------------------------------------------------------------
  69. def main():
  70. parser = argparse.ArgumentParser(
  71. formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  72. parser.add_argument('labels_file')
  73. parser.add_argument('in_dir')
  74. parser.add_argument('out_dir')
  75. args = parser.parse_args()
  76. if osp.exists(args.out_dir):
  77. print('Output directory already exists:', args.out_dir)
  78. quit(1)
  79. os.makedirs(args.out_dir)
  80. os.makedirs(osp.join(args.out_dir, 'JPEGImages'))
  81. os.makedirs(osp.join(args.out_dir, 'SegmentationClass'))
  82. os.makedirs(osp.join(args.out_dir, 'SegmentationClassVisualization'))
  83. print('Creating dataset:', args.out_dir)
  84. class_names = []
  85. class_name_to_id = {}
  86. for i, line in enumerate(open(args.labels_file).readlines()):
  87. class_id = i - 1 # starts with -1
  88. class_name = line.strip()
  89. class_name_to_id[class_name] = class_id
  90. if class_id == -1:
  91. assert class_name == '__ignore__'
  92. continue
  93. elif class_id == 0:
  94. assert class_name == '_background_'
  95. class_names.append(class_name)
  96. class_names = tuple(class_names)
  97. print('class_names:', class_names)
  98. out_class_names_file = osp.join(args.out_dir, 'class_names.txt')
  99. with open(out_class_names_file, 'w') as f:
  100. f.writelines('\n'.join(class_names))
  101. print('Saved class_names:', out_class_names_file)
  102. colormap = labelme.utils.label_colormap(255)
  103. for label_file in glob.glob(osp.join(args.in_dir, '*.json')):
  104. print('Generating dataset from:', label_file)
  105. with open(label_file) as f:
  106. base = osp.splitext(osp.basename(label_file))[0]
  107. out_img_file = osp.join(
  108. args.out_dir, 'JPEGImages', base + '.jpg')
  109. out_lbl_file = osp.join(
  110. args.out_dir, 'SegmentationClass', base + '.png')
  111. out_viz_file = osp.join(
  112. args.out_dir, 'SegmentationClassVisualization', base + '.jpg')
  113. data = json.load(f)
  114. img_file = osp.join(osp.dirname(label_file), data['imagePath'])
  115. img = skimage.io.imread(img_file)
  116. skimage.io.imsave(out_img_file, img)
  117. lbl = labelme_shapes_to_label(
  118. img_shape=img.shape,
  119. shapes=data['shapes'],
  120. label_name_to_value=class_name_to_id,
  121. )
  122. lbl_pil = PIL.Image.fromarray(lbl)
  123. # Only works with uint8 label
  124. # lbl_pil = PIL.Image.fromarray(lbl, mode='P')
  125. # lbl_pil.putpalette((colormap * 255).flatten())
  126. lbl_pil.save(out_lbl_file)
  127. viz = draw_label(
  128. lbl, img, class_names, colormap=colormap)
  129. skimage.io.imsave(out_viz_file, viz)
# Run the converter only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    main()