label_file.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
  1. import base64
  2. import json
  3. import os.path
  4. from . import logger
  5. from . import PY2
  6. from . import utils
  7. from ._version import __version__
  class LabelFileError(Exception):
      """Raised when a label file cannot be read or written."""
  class LabelFile(object):
      """Read and write labelme-style JSON annotation files.

      ``load`` (also run by the constructor when a filename is given) exposes
      the parsed file through attributes: ``shapes``, ``imagePath``,
      ``imageData`` (raw image bytes), ``lineColor``, ``fillColor``,
      ``flags``, ``filename`` and ``otherData`` (every top-level key that
      ``load`` does not interpret).  ``save`` writes the same format back.
      """

      suffix = '.json'

      def __init__(self, filename=None):
          """Create an empty label file, loading *filename* if it is given.

          Raises:
              LabelFileError: if *filename* is given and cannot be loaded.
          """
          self.shapes = ()
          self.imagePath = None
          self.imageData = None
          # Give every attribute that ``load`` assigns a value up front so an
          # instance that was never loaded does not raise AttributeError.
          self.lineColor = None
          self.fillColor = None
          self.flags = None
          self.otherData = {}
          if filename is not None:
              self.load(filename)
          self.filename = filename

      @staticmethod
      def _open(filename, mode):
          """Open a label JSON file for text reading ('r') or writing ('w').

          On Python 3 the file is opened explicitly as UTF-8, because the
          platform default encoding (e.g. cp1252 on Windows) cannot represent
          every label that ``json.dump(..., ensure_ascii=False)`` writes.
          On Python 2 the original binary mode is kept unchanged.
          """
          if PY2:
              return open(filename, mode + 'b')
          return open(filename, mode, encoding='utf-8')

      def load(self, filename):
          """Parse *filename* and replace this object's state with its contents.

          The image bytes come from the embedded base64 ``imageData`` when it
          is not null; otherwise they are read from ``imagePath`` resolved
          against the label file's directory.  Attributes are assigned only
          after the whole file has been parsed, so a failed load leaves the
          previous state untouched.

          Raises:
              LabelFileError: wrapping any error raised while reading or
                  parsing the file or its image.
          """
          # Top-level keys interpreted here; every other key is preserved in
          # ``otherData``.  'version' is included because ``save`` writes the
          # current version and then applies ``otherData`` on top, so carrying
          # the file's old version through ``otherData`` would overwrite it.
          keys = [
              'version',
              'imageData',
              'imagePath',
              'lineColor',
              'fillColor',
              'shapes',  # polygonal annotations
              'flags',  # image level flags
              'imageHeight',
              'imageWidth',
          ]
          try:
              with self._open(filename, 'r') as f:
                  data = json.load(f)
              if data['imageData'] is not None:
                  imageData = base64.b64decode(data['imageData'])
              else:
                  # imagePath is relative to the label file; join it with the
                  # label file's directory so it resolves from the cwd.
                  imagePath = os.path.join(os.path.dirname(filename),
                                           data['imagePath'])
                  with open(imagePath, 'rb') as f:
                      imageData = f.read()
              flags = data.get('flags')
              imagePath = data['imagePath']
              # Called only for the error it logs on a size mismatch; the
              # corrected sizes it returns are not stored here.
              self._check_image_height_and_width(
                  base64.b64encode(imageData).decode('utf-8'),
                  data.get('imageHeight'),
                  data.get('imageWidth'),
              )
              lineColor = data['lineColor']
              fillColor = data['fillColor']
              # Build the tuple eagerly: a lazy generator would raise a missing
              # key only when first iterated (outside this try, so not as a
              # LabelFileError) and could be iterated only once.
              shapes = tuple(
                  (
                      s['label'],
                      s['points'],
                      s['line_color'],
                      s['fill_color'],
                      s.get('shape_type', 'polygon'),
                  )
                  for s in data['shapes']
              )
          except Exception as e:
              raise LabelFileError(e)
          otherData = {
              key: value for key, value in data.items() if key not in keys
          }
          # Only replace state after everything has loaded successfully.
          self.flags = flags
          self.shapes = shapes
          self.imagePath = imagePath
          self.imageData = imageData
          self.lineColor = lineColor
          self.fillColor = fillColor
          self.filename = filename
          self.otherData = otherData

      @staticmethod
      def _check_image_height_and_width(imageData, imageHeight, imageWidth):
          """Reconcile stored image dimensions with the actual image.

          Args:
              imageData: base64-encoded image, as text.
              imageHeight: stored height, or None if unknown.
              imageWidth: stored width, or None if unknown.

          Returns:
              ``(imageHeight, imageWidth)``.  A value that was given but
              differs from the decoded array's ``shape[0]`` / ``shape[1]`` is
              replaced by the actual value and an error is logged; None
              inputs are returned unchanged.
          """
          img_arr = utils.img_b64_to_arr(imageData)
          if imageHeight is not None and img_arr.shape[0] != imageHeight:
              logger.error(
                  'imageHeight does not match with imageData or imagePath, '
                  'so getting imageHeight from actual image.'
              )
              imageHeight = img_arr.shape[0]
          if imageWidth is not None and img_arr.shape[1] != imageWidth:
              logger.error(
                  'imageWidth does not match with imageData or imagePath, '
                  'so getting imageWidth from actual image.'
              )
              imageWidth = img_arr.shape[1]
          return imageHeight, imageWidth

      def save(
          self,
          filename,
          shapes,
          imagePath,
          imageHeight,
          imageWidth,
          imageData=None,
          lineColor=None,
          fillColor=None,
          otherData=None,
          flags=None,
      ):
          """Write an annotation to *filename* as JSON.

          Args:
              filename: destination path; remembered as ``self.filename`` on
                  success.
              shapes: shape records, stored as given.
              imagePath: image path recorded in the file.
              imageHeight, imageWidth: image dimensions; when *imageData* is
                  given, a value that disagrees with the image is replaced by
                  the actual dimension.
              imageData: raw image bytes to embed base64-encoded, or None to
                  store ``null``.
              lineColor, fillColor: stored as-is under the same keys.
              otherData: extra top-level keys; applied last, so they override
                  standard keys of the same name.
              flags: image-level flags; ``{}`` when None.

          Raises:
              LabelFileError: if the file cannot be opened or serialized.
          """
          if imageData is not None:
              imageData = base64.b64encode(imageData).decode('utf-8')
              imageHeight, imageWidth = self._check_image_height_and_width(
                  imageData, imageHeight, imageWidth
              )
          if otherData is None:
              otherData = {}
          if flags is None:
              flags = {}
          data = dict(
              version=__version__,
              flags=flags,
              shapes=shapes,
              lineColor=lineColor,
              fillColor=fillColor,
              imagePath=imagePath,
              imageData=imageData,
              imageHeight=imageHeight,
              imageWidth=imageWidth,
          )
          data.update(otherData)
          try:
              with self._open(filename, 'w') as f:
                  json.dump(data, f, ensure_ascii=False, indent=2)
              self.filename = filename
          except Exception as e:
              raise LabelFileError(e)

      @staticmethod
      def isLabelFile(filename):
          """Return True if *filename*'s extension equals ``suffix``, ignoring case."""
          return os.path.splitext(filename)[1].lower() == LabelFile.suffix