瀏覽代碼

Fix image orientation in bbox_detection/labelme2voc.py

Kentaro Wada 5 年之前
父節點
當前提交
08aa0d2725
共有 1 個文件被更改,包括 10 次插入12 次删除
  1. 10 12
      examples/bbox_detection/labelme2voc.py

+ 10 - 12
examples/bbox_detection/labelme2voc.py

@@ -4,20 +4,18 @@ from __future__ import print_function
 
 import argparse
 import glob
-import json
 import os
 import os.path as osp
 import sys
 
 import imgviz
+import labelme
 try:
     import lxml.builder
     import lxml.etree
 except ImportError:
     print('Please install lxml:\n\n    pip install lxml\n')
     sys.exit(1)
-import numpy as np
-import PIL.Image
 
 
 def main():
@@ -61,11 +59,12 @@ def main():
         f.writelines('\n'.join(class_names))
     print('Saved class_names:', out_class_names_file)
 
-    for label_file in glob.glob(osp.join(args.input_dir, '*.json')):
-        print('Generating dataset from:', label_file)
-        with open(label_file) as f:
-            data = json.load(f)
-        base = osp.splitext(osp.basename(label_file))[0]
+    for filename in glob.glob(osp.join(args.input_dir, '*.json')):
+        print('Generating dataset from:', filename)
+
+        label_file = labelme.LabelFile(filename=filename)
+
+        base = osp.splitext(osp.basename(filename))[0]
         out_img_file = osp.join(
             args.output_dir, 'JPEGImages', base + '.jpg')
         out_xml_file = osp.join(
@@ -74,9 +73,8 @@ def main():
             out_viz_file = osp.join(
                 args.output_dir, 'AnnotationsVisualization', base + '.jpg')
 
-        img_file = osp.join(osp.dirname(label_file), data['imagePath'])
-        img = np.asarray(PIL.Image.open(img_file))
-        PIL.Image.fromarray(img).save(out_img_file)
+        img = labelme.utils.img_data_to_arr(label_file.imageData)
+        imgviz.io.imsave(out_img_file, img)
 
         maker = lxml.builder.ElementMaker()
         xml = maker.annotation(
@@ -95,7 +93,7 @@ def main():
 
         bboxes = []
         labels = []
-        for shape in data['shapes']:
+        for shape in label_file.shapes:
             if shape['shape_type'] != 'rectangle':
                 print('Skipping shape: label={label}, shape_type={shape_type}'
                       .format(**shape))