labelFile.py 2.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. import base64
  2. import json
  3. import os.path
  4. import sys
  5. PY2 = sys.version_info[0] == 2
  6. class LabelFileError(Exception):
  7. pass
  8. class LabelFile(object):
  9. suffix = '.json'
  10. def __init__(self, filename=None):
  11. self.shapes = ()
  12. self.imagePath = None
  13. self.imageData = None
  14. if filename is not None:
  15. self.load(filename)
  16. self.filename = filename
  17. def load(self, filename):
  18. try:
  19. with open(filename, 'rb' if PY2 else 'r') as f:
  20. data = json.load(f)
  21. if data['imageData'] is not None:
  22. imageData = base64.b64decode(data['imageData'])
  23. else:
  24. # relative path from label file to relative path from cwd
  25. imagePath = os.path.join(os.path.dirname(filename),
  26. data['imagePath'])
  27. with open(imagePath, 'rb') as f:
  28. imageData = f.read()
  29. lineColor = data['lineColor']
  30. fillColor = data['fillColor']
  31. shapes = (
  32. (s['label'], s['points'], s['line_color'], s['fill_color'])
  33. for s in data['shapes']
  34. )
  35. # Only replace data after everything is loaded.
  36. self.shapes = shapes
  37. self.imagePath = data['imagePath']
  38. self.imageData = imageData
  39. self.lineColor = lineColor
  40. self.fillColor = fillColor
  41. self.filename = filename
  42. except Exception as e:
  43. raise LabelFileError(e)
  44. def save(self, filename, shapes, imagePath, imageData=None,
  45. lineColor=None, fillColor=None):
  46. if imageData is not None:
  47. imageData = base64.b64encode(imageData).decode('utf-8')
  48. data = dict(
  49. shapes=shapes,
  50. lineColor=lineColor,
  51. fillColor=fillColor,
  52. imagePath=imagePath,
  53. imageData=imageData,
  54. )
  55. try:
  56. with open(filename, 'wb' if PY2 else 'w') as f:
  57. json.dump(data, f, ensure_ascii=True, indent=2)
  58. self.filename = filename
  59. except Exception as e:
  60. raise LabelFileError(e)
  61. @staticmethod
  62. def isLabelFile(filename):
  63. return os.path.splitext(filename)[1].lower() == LabelFile.suffix