소스 검색

Fix error on saving/loading when using Python3

Kentaro Wada 7 년 전
부모
커밋
2b020e270f
1개의 변경된 파일11개의 추가작업 그리고 20개의 파일을 삭제
  1. 11 20
      labelme/labelFile.py

+ 11 - 20
labelme/labelFile.py

@@ -39,15 +39,10 @@ class LabelFile(object):
 
     def load(self, filename):
         try:
-            with open(filename, 'rb') as f:
+            with open(filename, 'rb' if six.PY2 else 'r') as f:
                 data = json.load(f)
                 imagePath = data['imagePath']
-                if six.PY3:
-                    imageData = b64decode(data['imageData']).decode('utf-8')
-                elif six.PY2:
-                    imageData = b64decode(data['imageData'])
-                else:
-                    raise RuntimeError('Unsupported Python version.')
+                imageData = b64decode(data['imageData'])
                 lineColor = data['lineColor']
                 fillColor = data['fillColor']
                 shapes = ((s['label'], s['points'], s['line_color'], s['fill_color'])\
@@ -63,20 +58,16 @@ class LabelFile(object):
 
     def save(self, filename, shapes, imagePath, imageData,
             lineColor=None, fillColor=None):
+        data = dict(
+            shapes=shapes,
+            lineColor=lineColor,
+            fillColor=fillColor,
+            imagePath=imagePath,
+            imageData=b64encode(imageData).decode('utf-8'),
+        )
         try:
-            with open(filename, 'wb') as f:
-                if six.PY3:
-                    imageData = b64encode(imageData.encode('utf-8'))
-                elif six.PY2:
-                    imageData = b64encode(imageData)
-                else:
-                    raise RuntimeError('Unsupported Python version.')
-                json.dump(dict(
-                    shapes=shapes,
-                    lineColor=lineColor, fillColor=fillColor,
-                    imagePath=imagePath,
-                    imageData=imageData),
-                    f, ensure_ascii=True, indent=2)
+            with open(filename, 'wb' if six.PY2 else 'w') as f:
+                json.dump(data, f, ensure_ascii=True, indent=2)
         except Exception as e:
             raise LabelFileError(e)