label_file.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. import base64
  2. import io
  3. import json
  4. import os.path as osp
  5. import PIL.Image
  6. from labelme import __version__
  7. from labelme.logger import logger
  8. from labelme import PY2
  9. from labelme import QT4
  10. from labelme import utils
# Remove Pillow's maximum-pixel-count limit so that very large images can be
# opened. This mutates a module attribute of PIL.Image, so it affects every
# Pillow user in the current process, not just this module.
PIL.Image.MAX_IMAGE_PIXELS = None
  12. class LabelFileError(Exception):
  13. pass
  14. class LabelFile(object):
  15. suffix = ".json"
  16. def __init__(self, filename=None):
  17. self.shapes = []
  18. self.imagePath = None
  19. self.imageData = None
  20. if filename is not None:
  21. self.load(filename)
  22. self.filename = filename
  23. @staticmethod
  24. def load_image_file(filename):
  25. try:
  26. image_pil = PIL.Image.open(filename)
  27. except IOError:
  28. logger.error("Failed opening image file: {}".format(filename))
  29. return
  30. # apply orientation to image according to exif
  31. image_pil = utils.apply_exif_orientation(image_pil)
  32. with io.BytesIO() as f:
  33. ext = osp.splitext(filename)[1].lower()
  34. if PY2 and QT4:
  35. format = "PNG"
  36. elif ext in [".jpg", ".jpeg"]:
  37. format = "JPEG"
  38. else:
  39. format = "PNG"
  40. image_pil.save(f, format=format)
  41. f.seek(0)
  42. return f.read()
  43. def load(self, filename):
  44. keys = [
  45. "version",
  46. "imageData",
  47. "imagePath",
  48. "shapes", # polygonal annotations
  49. "flags", # image level flags
  50. "imageHeight",
  51. "imageWidth",
  52. ]
  53. shape_keys = [
  54. "label",
  55. "points",
  56. "group_id",
  57. "shape_type",
  58. "flags",
  59. ]
  60. try:
  61. with open(filename, "rb" if PY2 else "r") as f:
  62. data = json.load(f)
  63. version = data.get("version")
  64. if version is None:
  65. logger.warn(
  66. "Loading JSON file ({}) of unknown version".format(
  67. filename
  68. )
  69. )
  70. elif version.split(".")[0] != __version__.split(".")[0]:
  71. logger.warn(
  72. "This JSON file ({}) may be incompatible with "
  73. "current labelme. version in file: {}, "
  74. "current version: {}".format(
  75. filename, version, __version__
  76. )
  77. )
  78. if data["imageData"] is not None:
  79. imageData = base64.b64decode(data["imageData"])
  80. if PY2 and QT4:
  81. imageData = utils.img_data_to_png_data(imageData)
  82. else:
  83. # relative path from label file to relative path from cwd
  84. imagePath = osp.join(osp.dirname(filename), data["imagePath"])
  85. imageData = self.load_image_file(imagePath)
  86. flags = data.get("flags") or {}
  87. imagePath = data["imagePath"]
  88. self._check_image_height_and_width(
  89. base64.b64encode(imageData).decode("utf-8"),
  90. data.get("imageHeight"),
  91. data.get("imageWidth"),
  92. )
  93. shapes = [
  94. dict(
  95. label=s["label"],
  96. points=s["points"],
  97. shape_type=s.get("shape_type", "polygon"),
  98. flags=s.get("flags", {}),
  99. group_id=s.get("group_id"),
  100. other_data={
  101. k: v for k, v in s.items() if k not in shape_keys
  102. },
  103. )
  104. for s in data["shapes"]
  105. ]
  106. except Exception as e:
  107. raise LabelFileError(e)
  108. otherData = {}
  109. for key, value in data.items():
  110. if key not in keys:
  111. otherData[key] = value
  112. # Only replace data after everything is loaded.
  113. self.flags = flags
  114. self.shapes = shapes
  115. self.imagePath = imagePath
  116. self.imageData = imageData
  117. self.filename = filename
  118. self.otherData = otherData
  119. @staticmethod
  120. def _check_image_height_and_width(imageData, imageHeight, imageWidth):
  121. img_arr = utils.img_b64_to_arr(imageData)
  122. if imageHeight is not None and img_arr.shape[0] != imageHeight:
  123. logger.error(
  124. "imageHeight does not match with imageData or imagePath, "
  125. "so getting imageHeight from actual image."
  126. )
  127. imageHeight = img_arr.shape[0]
  128. if imageWidth is not None and img_arr.shape[1] != imageWidth:
  129. logger.error(
  130. "imageWidth does not match with imageData or imagePath, "
  131. "so getting imageWidth from actual image."
  132. )
  133. imageWidth = img_arr.shape[1]
  134. return imageHeight, imageWidth
  135. def save(
  136. self,
  137. filename,
  138. shapes,
  139. imagePath,
  140. imageHeight,
  141. imageWidth,
  142. imageData=None,
  143. otherData=None,
  144. flags=None,
  145. ):
  146. if imageData is not None:
  147. imageData = base64.b64encode(imageData).decode("utf-8")
  148. imageHeight, imageWidth = self._check_image_height_and_width(
  149. imageData, imageHeight, imageWidth
  150. )
  151. if otherData is None:
  152. otherData = {}
  153. if flags is None:
  154. flags = {}
  155. data = dict(
  156. version=__version__,
  157. flags=flags,
  158. shapes=shapes,
  159. imagePath=imagePath,
  160. imageData=imageData,
  161. imageHeight=imageHeight,
  162. imageWidth=imageWidth,
  163. )
  164. for key, value in otherData.items():
  165. assert key not in data
  166. data[key] = value
  167. try:
  168. with open(filename, "wb" if PY2 else "w") as f:
  169. json.dump(data, f, ensure_ascii=False, indent=2)
  170. self.filename = filename
  171. except Exception as e:
  172. raise LabelFileError(e)
  173. @staticmethod
  174. def is_label_file(filename):
  175. return osp.splitext(filename)[1].lower() == LabelFile.suffix