label_file.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
  1. import base64
  2. import io
  3. import json
  4. import os.path as osp
  5. import PIL.Image
  6. from labelme._version import __version__
  7. from labelme.logger import logger
  8. from labelme import PY2
  9. from labelme import QT4
  10. from labelme import utils
  11. class LabelFileError(Exception):
  12. pass
  13. class LabelFile(object):
  14. suffix = '.json'
  15. def __init__(self, filename=None):
  16. self.shapes = ()
  17. self.imagePath = None
  18. self.imageData = None
  19. if filename is not None:
  20. self.load(filename)
  21. self.filename = filename
  22. @staticmethod
  23. def load_image_file(filename):
  24. try:
  25. image_pil = PIL.Image.open(filename)
  26. except IOError:
  27. logger.error('Failed opening image file: {}'.format(filename))
  28. return
  29. # apply orientation to image according to exif
  30. image_pil = utils.apply_exif_orientation(image_pil)
  31. with io.BytesIO() as f:
  32. ext = osp.splitext(filename)[1].lower()
  33. if PY2 and QT4:
  34. format = 'PNG'
  35. elif ext in ['.jpg', '.jpeg']:
  36. format = 'JPEG'
  37. else:
  38. format = 'PNG'
  39. image_pil.save(f, format=format)
  40. f.seek(0)
  41. return f.read()
  42. def load(self, filename):
  43. keys = [
  44. 'imageData',
  45. 'imagePath',
  46. 'lineColor',
  47. 'fillColor',
  48. 'shapes', # polygonal annotations
  49. 'flags', # image level flags
  50. 'imageHeight',
  51. 'imageWidth',
  52. ]
  53. try:
  54. with open(filename, 'rb' if PY2 else 'r') as f:
  55. data = json.load(f)
  56. if data['imageData'] is not None:
  57. imageData = base64.b64decode(data['imageData'])
  58. if PY2 and QT4:
  59. imageData = utils.img_data_to_png_data(imageData)
  60. else:
  61. # relative path from label file to relative path from cwd
  62. imagePath = osp.join(osp.dirname(filename), data['imagePath'])
  63. imageData = self.load_image_file(imagePath)
  64. flags = data.get('flags')
  65. imagePath = data['imagePath']
  66. self._check_image_height_and_width(
  67. base64.b64encode(imageData).decode('utf-8'),
  68. data.get('imageHeight'),
  69. data.get('imageWidth'),
  70. )
  71. lineColor = data['lineColor']
  72. fillColor = data['fillColor']
  73. shapes = (
  74. (
  75. s['label'],
  76. s['points'],
  77. s['line_color'],
  78. s['fill_color'],
  79. s.get('shape_type', 'polygon'),
  80. s.get('flags', {}),
  81. )
  82. for s in data['shapes']
  83. )
  84. except Exception as e:
  85. raise LabelFileError(e)
  86. otherData = {}
  87. for key, value in data.items():
  88. if key not in keys:
  89. otherData[key] = value
  90. # Only replace data after everything is loaded.
  91. self.flags = flags
  92. self.shapes = shapes
  93. self.imagePath = imagePath
  94. self.imageData = imageData
  95. self.lineColor = lineColor
  96. self.fillColor = fillColor
  97. self.filename = filename
  98. self.otherData = otherData
  99. @staticmethod
  100. def _check_image_height_and_width(imageData, imageHeight, imageWidth):
  101. img_arr = utils.img_b64_to_arr(imageData)
  102. if imageHeight is not None and img_arr.shape[0] != imageHeight:
  103. logger.error(
  104. 'imageHeight does not match with imageData or imagePath, '
  105. 'so getting imageHeight from actual image.'
  106. )
  107. imageHeight = img_arr.shape[0]
  108. if imageWidth is not None and img_arr.shape[1] != imageWidth:
  109. logger.error(
  110. 'imageWidth does not match with imageData or imagePath, '
  111. 'so getting imageWidth from actual image.'
  112. )
  113. imageWidth = img_arr.shape[1]
  114. return imageHeight, imageWidth
  115. def save(
  116. self,
  117. filename,
  118. shapes,
  119. imagePath,
  120. imageHeight,
  121. imageWidth,
  122. imageData=None,
  123. lineColor=None,
  124. fillColor=None,
  125. otherData=None,
  126. flags=None,
  127. ):
  128. if imageData is not None:
  129. imageData = base64.b64encode(imageData).decode('utf-8')
  130. imageHeight, imageWidth = self._check_image_height_and_width(
  131. imageData, imageHeight, imageWidth
  132. )
  133. if otherData is None:
  134. otherData = {}
  135. if flags is None:
  136. flags = {}
  137. data = dict(
  138. version=__version__,
  139. flags=flags,
  140. shapes=shapes,
  141. lineColor=lineColor,
  142. fillColor=fillColor,
  143. imagePath=imagePath,
  144. imageData=imageData,
  145. imageHeight=imageHeight,
  146. imageWidth=imageWidth,
  147. )
  148. for key, value in otherData.items():
  149. data[key] = value
  150. try:
  151. with open(filename, 'wb' if PY2 else 'w') as f:
  152. json.dump(data, f, ensure_ascii=False, indent=2)
  153. self.filename = filename
  154. except Exception as e:
  155. raise LabelFileError(e)
  156. @staticmethod
  157. def is_label_file(filename):
  158. return osp.splitext(filename)[1].lower() == LabelFile.suffix