label_file.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194
  1. import base64
  2. import io
  3. import json
  4. import os.path as osp
  5. import PIL.Image
  6. from labelme import __version__
  7. from labelme.logger import logger
  8. from labelme import PY2
  9. from labelme import QT4
  10. from labelme import utils
  11. PIL.Image.MAX_IMAGE_PIXELS = None
  12. class LabelFileError(Exception):
  13. pass
  14. class LabelFile(object):
  15. suffix = '.json'
  16. def __init__(self, filename=None):
  17. self.shapes = []
  18. self.imagePath = None
  19. self.imageData = None
  20. if filename is not None:
  21. self.load(filename)
  22. self.filename = filename
  23. @staticmethod
  24. def load_image_file(filename):
  25. try:
  26. image_pil = PIL.Image.open(filename)
  27. except IOError:
  28. logger.error('Failed opening image file: {}'.format(filename))
  29. return
  30. # apply orientation to image according to exif
  31. image_pil = utils.apply_exif_orientation(image_pil)
  32. with io.BytesIO() as f:
  33. ext = osp.splitext(filename)[1].lower()
  34. if PY2 and QT4:
  35. format = 'PNG'
  36. elif ext in ['.jpg', '.jpeg']:
  37. format = 'JPEG'
  38. else:
  39. format = 'PNG'
  40. image_pil.save(f, format=format)
  41. f.seek(0)
  42. return f.read()
  43. def load(self, filename):
  44. keys = [
  45. 'version',
  46. 'imageData',
  47. 'imagePath',
  48. 'shapes', # polygonal annotations
  49. 'flags', # image level flags
  50. 'imageHeight',
  51. 'imageWidth',
  52. ]
  53. shape_keys = [
  54. 'label',
  55. 'points',
  56. 'group_id',
  57. 'shape_type',
  58. 'flags',
  59. ]
  60. try:
  61. with open(filename, 'rb' if PY2 else 'r') as f:
  62. data = json.load(f)
  63. version = data.get('version')
  64. if version is None:
  65. logger.warn(
  66. 'Loading JSON file ({}) of unknown version'
  67. .format(filename)
  68. )
  69. elif version.split('.')[0] != __version__.split('.')[0]:
  70. logger.warn(
  71. 'This JSON file ({}) may be incompatible with '
  72. 'current labelme. version in file: {}, '
  73. 'current version: {}'.format(
  74. filename, version, __version__
  75. )
  76. )
  77. if data['imageData'] is not None:
  78. imageData = base64.b64decode(data['imageData'])
  79. if PY2 and QT4:
  80. imageData = utils.img_data_to_png_data(imageData)
  81. else:
  82. # relative path from label file to relative path from cwd
  83. imagePath = osp.join(osp.dirname(filename), data['imagePath'])
  84. imageData = self.load_image_file(imagePath)
  85. flags = data.get('flags') or {}
  86. imagePath = data['imagePath']
  87. self._check_image_height_and_width(
  88. base64.b64encode(imageData).decode('utf-8'),
  89. data.get('imageHeight'),
  90. data.get('imageWidth'),
  91. )
  92. shapes = [
  93. dict(
  94. label=s['label'],
  95. points=s['points'],
  96. shape_type=s.get('shape_type', 'polygon'),
  97. flags=s.get('flags', {}),
  98. group_id=s.get('group_id'),
  99. other_data={
  100. k: v for k, v in s.items() if k not in shape_keys
  101. }
  102. )
  103. for s in data['shapes']
  104. ]
  105. except Exception as e:
  106. raise LabelFileError(e)
  107. otherData = {}
  108. for key, value in data.items():
  109. if key not in keys:
  110. otherData[key] = value
  111. # Only replace data after everything is loaded.
  112. self.flags = flags
  113. self.shapes = shapes
  114. self.imagePath = imagePath
  115. self.imageData = imageData
  116. self.filename = filename
  117. self.otherData = otherData
  118. @staticmethod
  119. def _check_image_height_and_width(imageData, imageHeight, imageWidth):
  120. img_arr = utils.img_b64_to_arr(imageData)
  121. if imageHeight is not None and img_arr.shape[0] != imageHeight:
  122. logger.error(
  123. 'imageHeight does not match with imageData or imagePath, '
  124. 'so getting imageHeight from actual image.'
  125. )
  126. imageHeight = img_arr.shape[0]
  127. if imageWidth is not None and img_arr.shape[1] != imageWidth:
  128. logger.error(
  129. 'imageWidth does not match with imageData or imagePath, '
  130. 'so getting imageWidth from actual image.'
  131. )
  132. imageWidth = img_arr.shape[1]
  133. return imageHeight, imageWidth
  134. def save(
  135. self,
  136. filename,
  137. shapes,
  138. imagePath,
  139. imageHeight,
  140. imageWidth,
  141. imageData=None,
  142. otherData=None,
  143. flags=None,
  144. ):
  145. if imageData is not None:
  146. imageData = base64.b64encode(imageData).decode('utf-8')
  147. imageHeight, imageWidth = self._check_image_height_and_width(
  148. imageData, imageHeight, imageWidth
  149. )
  150. if otherData is None:
  151. otherData = {}
  152. if flags is None:
  153. flags = {}
  154. data = dict(
  155. version=__version__,
  156. flags=flags,
  157. shapes=shapes,
  158. imagePath=imagePath,
  159. imageData=imageData,
  160. imageHeight=imageHeight,
  161. imageWidth=imageWidth,
  162. )
  163. for key, value in otherData.items():
  164. assert key not in data
  165. data[key] = value
  166. try:
  167. with open(filename, 'wb' if PY2 else 'w') as f:
  168. json.dump(data, f, ensure_ascii=False, indent=2)
  169. self.filename = filename
  170. except Exception as e:
  171. raise LabelFileError(e)
  172. @staticmethod
  173. def is_label_file(filename):
  174. return osp.splitext(filename)[1].lower() == LabelFile.suffix