labelme2voc.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. #!/usr/bin/env python
  2. from __future__ import print_function
  3. import argparse
  4. import glob
  5. import json
  6. import os
  7. import os.path as osp
  8. import numpy as np
  9. import PIL.Image
  10. import PIL.ImagePalette
  11. import labelme
  12. def main():
  13. parser = argparse.ArgumentParser(
  14. formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  15. parser.add_argument('labels_file')
  16. parser.add_argument('in_dir')
  17. parser.add_argument('out_dir')
  18. args = parser.parse_args()
  19. if osp.exists(args.out_dir):
  20. print('Output directory already exists:', args.out_dir)
  21. quit(1)
  22. os.makedirs(args.out_dir)
  23. os.makedirs(osp.join(args.out_dir, 'JPEGImages'))
  24. os.makedirs(osp.join(args.out_dir, 'SegmentationClass'))
  25. os.makedirs(osp.join(args.out_dir, 'SegmentationClassVisualization'))
  26. print('Creating dataset:', args.out_dir)
  27. class_names = []
  28. class_name_to_id = {}
  29. for i, line in enumerate(open(args.labels_file).readlines()):
  30. class_id = i - 1 # starts with -1
  31. class_name = line.strip()
  32. class_name_to_id[class_name] = class_id
  33. if class_id == -1:
  34. assert class_name == '__ignore__'
  35. continue
  36. elif class_id == 0:
  37. assert class_name == '_background_'
  38. class_names.append(class_name)
  39. class_names = tuple(class_names)
  40. print('class_names:', class_names)
  41. out_class_names_file = osp.join(args.out_dir, 'class_names.txt')
  42. with open(out_class_names_file, 'w') as f:
  43. f.writelines('\n'.join(class_names))
  44. print('Saved class_names:', out_class_names_file)
  45. colormap = labelme.utils.label_colormap(255)
  46. for label_file in glob.glob(osp.join(args.in_dir, '*.json')):
  47. print('Generating dataset from:', label_file)
  48. with open(label_file) as f:
  49. base = osp.splitext(osp.basename(label_file))[0]
  50. out_img_file = osp.join(
  51. args.out_dir, 'JPEGImages', base + '.jpg')
  52. out_lbl_file = osp.join(
  53. args.out_dir, 'SegmentationClass', base + '.png')
  54. out_viz_file = osp.join(
  55. args.out_dir, 'SegmentationClassVisualization', base + '.jpg')
  56. data = json.load(f)
  57. img_file = osp.join(osp.dirname(label_file), data['imagePath'])
  58. img = np.asarray(PIL.Image.open(img_file))
  59. PIL.Image.fromarray(img).save(out_img_file)
  60. lbl = labelme.utils.shapes_to_label(
  61. img_shape=img.shape,
  62. shapes=data['shapes'],
  63. label_name_to_value=class_name_to_id,
  64. )
  65. lbl_pil = PIL.Image.fromarray(lbl)
  66. # Only works with uint8 label
  67. # lbl_pil = PIL.Image.fromarray(lbl, mode='P')
  68. # lbl_pil.putpalette((colormap * 255).flatten())
  69. lbl_pil.save(out_lbl_file)
  70. label_names = ['%d: %s' % (cls_id, cls_name)
  71. for cls_id, cls_name in enumerate(class_names)]
  72. viz = labelme.utils.draw_label(
  73. lbl, img, label_names, colormap=colormap)
  74. PIL.Image.fromarray(viz).save(out_viz_file)
  75. if __name__ == '__main__':
  76. main()