label_file.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. import base64
  2. import json
  3. import os.path
  4. import sys
  5. from ._version import __version__
# True when the running interpreter's major version is 2. LabelFile uses this
# flag to pick binary vs. text mode when opening JSON files.
PY2 = sys.version_info[0] == 2
  7. class LabelFileError(Exception):
  8. pass
  9. class LabelFile(object):
  10. suffix = '.json'
  11. def __init__(self, filename=None):
  12. self.shapes = ()
  13. self.imagePath = None
  14. self.imageData = None
  15. if filename is not None:
  16. self.load(filename)
  17. self.filename = filename
  18. def load(self, filename):
  19. keys = [
  20. 'imageData',
  21. 'imagePath',
  22. 'lineColor',
  23. 'fillColor',
  24. 'shapes', # polygonal annotations
  25. 'flags', # image level flags
  26. ]
  27. try:
  28. with open(filename, 'rb' if PY2 else 'r') as f:
  29. data = json.load(f)
  30. if data['imageData'] is not None:
  31. imageData = base64.b64decode(data['imageData'])
  32. else:
  33. # relative path from label file to relative path from cwd
  34. imagePath = os.path.join(os.path.dirname(filename),
  35. data['imagePath'])
  36. with open(imagePath, 'rb') as f:
  37. imageData = f.read()
  38. flags = data.get('flags')
  39. imagePath = data['imagePath']
  40. lineColor = data['lineColor']
  41. fillColor = data['fillColor']
  42. shapes = (
  43. (
  44. s['label'],
  45. s['points'],
  46. s['line_color'],
  47. s['fill_color'],
  48. s.get('shape_type'),
  49. )
  50. for s in data['shapes']
  51. )
  52. except Exception as e:
  53. raise LabelFileError(e)
  54. otherData = {}
  55. for key, value in data.items():
  56. if key not in keys:
  57. otherData[key] = value
  58. # Only replace data after everything is loaded.
  59. self.flags = flags
  60. self.shapes = shapes
  61. self.imagePath = imagePath
  62. self.imageData = imageData
  63. self.lineColor = lineColor
  64. self.fillColor = fillColor
  65. self.filename = filename
  66. self.otherData = otherData
  67. def save(self, filename, shapes, imagePath, imageData=None,
  68. lineColor=None, fillColor=None, otherData=None,
  69. flags=None):
  70. if imageData is not None:
  71. imageData = base64.b64encode(imageData).decode('utf-8')
  72. if otherData is None:
  73. otherData = {}
  74. if flags is None:
  75. flags = []
  76. data = dict(
  77. version=__version__,
  78. flags=flags,
  79. shapes=shapes,
  80. lineColor=lineColor,
  81. fillColor=fillColor,
  82. imagePath=imagePath,
  83. imageData=imageData,
  84. )
  85. for key, value in otherData.items():
  86. data[key] = value
  87. try:
  88. with open(filename, 'wb' if PY2 else 'w') as f:
  89. json.dump(data, f, ensure_ascii=False, indent=2)
  90. self.filename = filename
  91. except Exception as e:
  92. raise LabelFileError(e)
  93. @staticmethod
  94. def isLabelFile(filename):
  95. return os.path.splitext(filename)[1].lower() == LabelFile.suffix