label_file.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
  1. import base64
  2. import json
  3. import os.path
  4. import sys
  5. from . import logger
  6. from . import utils
  7. from ._version import __version__
  # True when running under Python 2; used to choose binary-mode file I/O
  # for JSON, since Python 2's json module works with byte strings there.
  PY2 = sys.version_info[:1] == (2,)
  class LabelFileError(Exception):
      """Raised when a label file cannot be loaded or saved."""
  class LabelFile(object):
      """JSON label file holding an image reference/data and its annotations.

      Attributes populated by ``load()`` (or left at their defaults):
      ``filename``, ``shapes``, ``imagePath``, ``imageData`` (raw bytes),
      ``flags``, ``lineColor``, ``fillColor`` and ``otherData`` (top-level
      JSON keys not handled explicitly, kept so they survive a save).
      """

      suffix = '.json'

      # Top-level keys handled explicitly by load()/save(). Any other key is
      # kept in ``otherData`` and written back by save(). 'version' is listed
      # so that a stale version string read from an old file is not carried
      # in otherData and then written over the current __version__ by save().
      _known_keys = frozenset([
          'version',
          'imageData',
          'imagePath',
          'lineColor',
          'fillColor',
          'shapes',  # polygonal annotations
          'flags',  # image level flags
          'imageHeight',
          'imageWidth',
      ])

      def __init__(self, filename=None):
          self.shapes = ()
          self.imagePath = None
          self.imageData = None
          # Defaults for the attributes otherwise set only by load(), so an
          # empty LabelFile() does not raise AttributeError when they are read.
          self.flags = None
          self.lineColor = None
          self.fillColor = None
          self.otherData = {}
          if filename is not None:
              self.load(filename)
          self.filename = filename

      @staticmethod
      def _open(filename, mode):
          """Open *filename* for JSON I/O; *mode* is 'r' or 'w'.

          On Python 2 the file is opened in binary mode (as before). On
          Python 3 it is opened as UTF-8 text, so non-ASCII labels written with
          ``ensure_ascii=False`` do not depend on the platform's default
          (locale) encoding.
          """
          if PY2:
              return open(filename, mode + 'b')
          return open(filename, mode, encoding='utf-8')

      @staticmethod
      def _parse_shapes(shapes):
          """Return a list of (label, points, line_color, fill_color, shape_type).

          Built eagerly rather than as a generator so that a malformed shape
          raises here (inside load()'s error handling) and so the result can
          be iterated more than once. ``shape_type`` defaults to 'polygon'.
          """
          return [
              (
                  s['label'],
                  s['points'],
                  s['line_color'],
                  s['fill_color'],
                  s.get('shape_type', 'polygon'),
              )
              for s in shapes
          ]

      def load(self, filename):
          """Read *filename* and replace this object's state with its contents.

          If the JSON ``imageData`` is null, the image bytes are read from
          ``imagePath`` resolved relative to the label file's directory.
          State is replaced only after everything has been read successfully.

          Raises:
              LabelFileError: the label file or referenced image cannot be
                  read, the JSON cannot be parsed, or a required key is
                  missing.
          """
          try:
              with self._open(filename, 'r') as f:
                  data = json.load(f)
              if data['imageData'] is not None:
                  imageData = base64.b64decode(data['imageData'])
              else:
                  # imagePath is stored relative to the label file; resolve it
                  # so it can be opened from the current working directory.
                  imagePath = os.path.join(os.path.dirname(filename),
                                           data['imagePath'])
                  with open(imagePath, 'rb') as f:
                      imageData = f.read()
              flags = data.get('flags')
              imagePath = data['imagePath']
              # Only logs on mismatch; the corrected values are not stored.
              self._check_image_height_and_width(
                  base64.b64encode(imageData).decode('utf-8'),
                  data.get('imageHeight'),
                  data.get('imageWidth'),
              )
              lineColor = data['lineColor']
              fillColor = data['fillColor']
              shapes = self._parse_shapes(data['shapes'])
          except Exception as e:
              # No ``raise ... from e``: this module stays Python 2 compatible.
              raise LabelFileError(e)

          otherData = {
              key: value
              for key, value in data.items()
              if key not in self._known_keys
          }

          # Only replace data after everything is loaded.
          self.flags = flags
          self.shapes = shapes
          self.imagePath = imagePath
          self.imageData = imageData
          self.lineColor = lineColor
          self.fillColor = fillColor
          self.filename = filename
          self.otherData = otherData

      @staticmethod
      def _check_image_height_and_width(imageData, imageHeight, imageWidth):
          """Return (imageHeight, imageWidth) corrected from the actual image.

          *imageData* is base64-encoded image data, decoded with
          ``utils.img_b64_to_arr``; the array's shape[0] / shape[1] are taken
          as height / width. A given value that disagrees is logged and
          replaced by the actual one; a ``None`` value is returned unchanged.
          """
          img_arr = utils.img_b64_to_arr(imageData)
          if imageHeight is not None and img_arr.shape[0] != imageHeight:
              logger.error(
                  'imageHeight does not match with imageData or imagePath, '
                  'so getting imageHeight from actual image.'
              )
              imageHeight = img_arr.shape[0]
          if imageWidth is not None and img_arr.shape[1] != imageWidth:
              logger.error(
                  'imageWidth does not match with imageData or imagePath, '
                  'so getting imageWidth from actual image.'
              )
              imageWidth = img_arr.shape[1]
          return imageHeight, imageWidth

      def save(
          self,
          filename,
          shapes,
          imagePath,
          imageHeight,
          imageWidth,
          imageData=None,
          lineColor=None,
          fillColor=None,
          otherData=None,
          flags=None,
      ):
          """Write the label data to *filename* as indented JSON.

          Args:
              filename: destination path; stored in ``self.filename`` on
                  success.
              shapes: JSON-serializable shapes, written as-is.
              imagePath: image path stored in the file.
              imageHeight, imageWidth: image size; corrected from the actual
                  image when *imageData* is given.
              imageData: raw image bytes, stored base64-encoded; ``None``
                  stores null.
              lineColor, fillColor: stored as-is.
              otherData: extra top-level keys; applied last, so they override
                  the standard keys on collision.
              flags: stored as-is; defaults to an empty list.

          Raises:
              LabelFileError: the file cannot be opened or written, or the
                  data cannot be serialized.
          """
          if imageData is not None:
              imageData = base64.b64encode(imageData).decode('utf-8')
              imageHeight, imageWidth = self._check_image_height_and_width(
                  imageData, imageHeight, imageWidth
              )
          if otherData is None:
              otherData = {}
          if flags is None:
              # NOTE(review): confirm whether consumers of 'flags' expect a
              # list or a mapping for this empty default.
              flags = []
          data = dict(
              version=__version__,
              flags=flags,
              shapes=shapes,
              lineColor=lineColor,
              fillColor=fillColor,
              imagePath=imagePath,
              imageData=imageData,
              imageHeight=imageHeight,
              imageWidth=imageWidth,
          )
          # Extra keys go last so they take precedence over the standard ones.
          data.update(otherData)
          try:
              with self._open(filename, 'w') as f:
                  json.dump(data, f, ensure_ascii=False, indent=2)
              self.filename = filename
          except Exception as e:
              # No ``raise ... from e``: this module stays Python 2 compatible.
              raise LabelFileError(e)

      @staticmethod
      def isLabelFile(filename):
          """Return True if *filename* ends with ``suffix`` (case-insensitive)."""
          return os.path.splitext(filename)[1].lower() == LabelFile.suffix