label_file.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. import base64
  2. import io
  3. import json
  4. import os.path as osp
  5. import PIL.Image
  6. from labelme._version import __version__
  7. from labelme.logger import logger
  8. from labelme import PY2
  9. from labelme import utils
  10. class LabelFileError(Exception):
  11. pass
  12. class LabelFile(object):
  13. suffix = '.json'
  14. def __init__(self, filename=None):
  15. self.shapes = ()
  16. self.imagePath = None
  17. self.imageData = None
  18. if filename is not None:
  19. self.load(filename)
  20. self.filename = filename
  21. @staticmethod
  22. def load_image_file(filename):
  23. try:
  24. image_pil = PIL.Image.open(filename)
  25. except IOError:
  26. return
  27. # apply orientation to image according to exif
  28. image_pil = utils.apply_exif_orientation(image_pil)
  29. with io.BytesIO() as f:
  30. ext = osp.splitext(filename)[1].lower()
  31. if ext in ['.jpg', '.jpeg']:
  32. image_pil.save(f, format='JPEG')
  33. else:
  34. image_pil.save(f, format='PNG')
  35. f.seek(0)
  36. return f.read()
  37. def load(self, filename):
  38. keys = [
  39. 'imageData',
  40. 'imagePath',
  41. 'lineColor',
  42. 'fillColor',
  43. 'shapes', # polygonal annotations
  44. 'flags', # image level flags
  45. 'imageHeight',
  46. 'imageWidth',
  47. ]
  48. try:
  49. with open(filename, 'rb' if PY2 else 'r') as f:
  50. data = json.load(f)
  51. if data['imageData'] is not None:
  52. imageData = base64.b64decode(data['imageData'])
  53. else:
  54. # relative path from label file to relative path from cwd
  55. imagePath = osp.join(osp.dirname(filename), data['imagePath'])
  56. imageData = self.load_image_file(imagePath)
  57. flags = data.get('flags')
  58. imagePath = data['imagePath']
  59. self._check_image_height_and_width(
  60. base64.b64encode(imageData).decode('utf-8'),
  61. data.get('imageHeight'),
  62. data.get('imageWidth'),
  63. )
  64. lineColor = data['lineColor']
  65. fillColor = data['fillColor']
  66. shapes = (
  67. (
  68. s['label'],
  69. s['points'],
  70. s['line_color'],
  71. s['fill_color'],
  72. s.get('shape_type', 'polygon'),
  73. )
  74. for s in data['shapes']
  75. )
  76. except Exception as e:
  77. raise LabelFileError(e)
  78. otherData = {}
  79. for key, value in data.items():
  80. if key not in keys:
  81. otherData[key] = value
  82. # Only replace data after everything is loaded.
  83. self.flags = flags
  84. self.shapes = shapes
  85. self.imagePath = imagePath
  86. self.imageData = imageData
  87. self.lineColor = lineColor
  88. self.fillColor = fillColor
  89. self.filename = filename
  90. self.otherData = otherData
  91. @staticmethod
  92. def _check_image_height_and_width(imageData, imageHeight, imageWidth):
  93. img_arr = utils.img_b64_to_arr(imageData)
  94. if imageHeight is not None and img_arr.shape[0] != imageHeight:
  95. logger.error(
  96. 'imageHeight does not match with imageData or imagePath, '
  97. 'so getting imageHeight from actual image.'
  98. )
  99. imageHeight = img_arr.shape[0]
  100. if imageWidth is not None and img_arr.shape[1] != imageWidth:
  101. logger.error(
  102. 'imageWidth does not match with imageData or imagePath, '
  103. 'so getting imageWidth from actual image.'
  104. )
  105. imageWidth = img_arr.shape[1]
  106. return imageHeight, imageWidth
  107. def save(
  108. self,
  109. filename,
  110. shapes,
  111. imagePath,
  112. imageHeight,
  113. imageWidth,
  114. imageData=None,
  115. lineColor=None,
  116. fillColor=None,
  117. otherData=None,
  118. flags=None,
  119. ):
  120. if imageData is not None:
  121. imageData = base64.b64encode(imageData).decode('utf-8')
  122. imageHeight, imageWidth = self._check_image_height_and_width(
  123. imageData, imageHeight, imageWidth
  124. )
  125. if otherData is None:
  126. otherData = {}
  127. if flags is None:
  128. flags = {}
  129. data = dict(
  130. version=__version__,
  131. flags=flags,
  132. shapes=shapes,
  133. lineColor=lineColor,
  134. fillColor=fillColor,
  135. imagePath=imagePath,
  136. imageData=imageData,
  137. imageHeight=imageHeight,
  138. imageWidth=imageWidth,
  139. )
  140. for key, value in otherData.items():
  141. data[key] = value
  142. try:
  143. with open(filename, 'wb' if PY2 else 'w') as f:
  144. json.dump(data, f, ensure_ascii=False, indent=2)
  145. self.filename = filename
  146. except Exception as e:
  147. raise LabelFileError(e)
  148. @staticmethod
  149. def is_label_file(filename):
  150. return osp.splitext(filename)[1].lower() == LabelFile.suffix