소스 검색

Fix string encoding for compatibility in Python2 and 3

Kentaro Wada 7 년 전
부모
커밋
9689cd7f27
1개의 변경된 파일16개의 추가작업 그리고 4개의 파일을 삭제
  1. 16 4
      labelme/labelFile.py

+ 16 - 4
labelme/labelFile.py

@@ -17,10 +17,12 @@
 # along with Labelme.  If not, see <http://www.gnu.org/licenses/>.
 #
 
+from base64 import b64encode, b64decode
 import json
 import os.path
 
-from base64 import b64encode, b64decode
+import six
+
 
 class LabelFileError(Exception):
     pass
@@ -40,7 +42,12 @@ class LabelFile(object):
             with open(filename, 'rb') as f:
                 data = json.load(f)
                 imagePath = data['imagePath']
-                imageData = b64decode(data['imageData'])
+                if six.PY3:
+                    imageData = b64decode(data['imageData']).decode('utf-8')
+                elif six.PY2:
+                    imageData = b64decode(data['imageData'])
+                else:
+                    raise RuntimeError('Unsupported Python version.')
                 lineColor = data['lineColor']
                 fillColor = data['fillColor']
                 shapes = ((s['label'], s['points'], s['line_color'], s['fill_color'])\
@@ -58,11 +65,17 @@ class LabelFile(object):
             lineColor=None, fillColor=None):
         try:
             with open(filename, 'wb') as f:
+                if six.PY3:
+                    imageData = b64encode(imageData.encode('utf-8'))
+                elif six.PY2:
+                    imageData = b64encode(imageData)
+                else:
+                    raise RuntimeError('Unsupported Python version.')
                 json.dump(dict(
                     shapes=shapes,
                     lineColor=lineColor, fillColor=fillColor,
                     imagePath=imagePath,
-                    imageData=b64encode(imageData)),
+                    imageData=imageData),
                     f, ensure_ascii=True, indent=2)
         except Exception as e:
             raise LabelFileError(e)
@@ -70,4 +83,3 @@ class LabelFile(object):
     @staticmethod
     def isLabelFile(filename):
         return os.path.splitext(filename)[1].lower() == LabelFile.suffix
-