labelFile.py 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. import base64
  2. import json
  3. import os.path
  4. import sys
  5. PY2 = sys.version_info[0] == 2
  6. class LabelFileError(Exception):
  7. pass
  8. class LabelFile(object):
  9. suffix = '.json'
  10. def __init__(self, filename=None):
  11. self.shapes = ()
  12. self.imagePath = None
  13. self.imageData = None
  14. if filename is not None:
  15. self.load(filename)
  16. self.filename = filename
  17. def load(self, filename):
  18. keys = ['imageData', 'imagePath', 'lineColor', 'fillColor', 'shapes']
  19. try:
  20. with open(filename, 'rb' if PY2 else 'r') as f:
  21. data = json.load(f)
  22. if data['imageData'] is not None:
  23. imageData = base64.b64decode(data['imageData'])
  24. else:
  25. # relative path from label file to relative path from cwd
  26. imagePath = os.path.join(os.path.dirname(filename),
  27. data['imagePath'])
  28. with open(imagePath, 'rb') as f:
  29. imageData = f.read()
  30. imagePath = data['imagePath']
  31. lineColor = data['lineColor']
  32. fillColor = data['fillColor']
  33. shapes = (
  34. (s['label'], s['points'], s['line_color'], s['fill_color'])
  35. for s in data['shapes']
  36. )
  37. except Exception as e:
  38. raise LabelFileError(e)
  39. otherData = {}
  40. for key, value in data.items():
  41. if key not in keys:
  42. otherData[key] = value
  43. # Only replace data after everything is loaded.
  44. self.shapes = shapes
  45. self.imagePath = imagePath
  46. self.imageData = imageData
  47. self.lineColor = lineColor
  48. self.fillColor = fillColor
  49. self.filename = filename
  50. self.otherData = otherData
  51. def save(self, filename, shapes, imagePath, imageData=None,
  52. lineColor=None, fillColor=None, otherData=None):
  53. if imageData is not None:
  54. imageData = base64.b64encode(imageData).decode('utf-8')
  55. if otherData is None:
  56. otherData = {}
  57. data = dict(
  58. shapes=shapes,
  59. lineColor=lineColor,
  60. fillColor=fillColor,
  61. imagePath=imagePath,
  62. imageData=imageData,
  63. )
  64. for key, value in otherData.items():
  65. data[key] = value
  66. try:
  67. with open(filename, 'wb' if PY2 else 'w') as f:
  68. json.dump(data, f, ensure_ascii=True, indent=2)
  69. self.filename = filename
  70. except Exception as e:
  71. raise LabelFileError(e)
  72. @staticmethod
  73. def isLabelFile(filename):
  74. return os.path.splitext(filename)[1].lower() == LabelFile.suffix