label_file.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
  1. import base64
  2. import contextlib
  3. import io
  4. import json
  5. import os.path as osp
  6. import PIL.Image
  7. from labelme import __version__
  8. from labelme.logger import logger
  9. from labelme import PY2
  10. from labelme import QT4
  11. from labelme import utils
# Setting Pillow's MAX_IMAGE_PIXELS to None disables its decompression-bomb
# pixel-count check, so very large images can be opened for labelling.
# NOTE: this mutates the PIL.Image module itself, so it is a process-wide
# side effect of importing this module.
PIL.Image.MAX_IMAGE_PIXELS = None
  13. @contextlib.contextmanager
  14. def open(name, mode):
  15. assert mode in ["r", "w"]
  16. if PY2:
  17. mode += "b"
  18. encoding = None
  19. else:
  20. encoding = "utf-8"
  21. yield io.open(name, mode, encoding=encoding)
  22. return
  23. class LabelFileError(Exception):
  24. pass
  25. class LabelFile(object):
  26. suffix = ".json"
  27. def __init__(self, filename=None):
  28. self.shapes = []
  29. self.imagePath = None
  30. self.imageData = None
  31. if filename is not None:
  32. self.load(filename)
  33. self.filename = filename
  34. @staticmethod
  35. def load_image_file(filename):
  36. try:
  37. image_pil = PIL.Image.open(filename)
  38. except IOError:
  39. logger.error("Failed opening image file: {}".format(filename))
  40. return
  41. # apply orientation to image according to exif
  42. image_pil = utils.apply_exif_orientation(image_pil)
  43. with io.BytesIO() as f:
  44. ext = osp.splitext(filename)[1].lower()
  45. if PY2 and QT4:
  46. format = "PNG"
  47. elif ext in [".jpg", ".jpeg"]:
  48. format = "JPEG"
  49. else:
  50. format = "PNG"
  51. image_pil.save(f, format=format)
  52. f.seek(0)
  53. return f.read()
  54. def load(self, filename):
  55. keys = [
  56. "version",
  57. "imageData",
  58. "imagePath",
  59. "shapes", # polygonal annotations
  60. "flags", # image level flags
  61. "imageHeight",
  62. "imageWidth",
  63. ]
  64. shape_keys = [
  65. "label",
  66. "points",
  67. "group_id",
  68. "shape_type",
  69. "flags",
  70. "content",
  71. ]
  72. try:
  73. with open(filename, "r") as f:
  74. data = json.load(f)
  75. version = data.get("version")
  76. if version is None:
  77. logger.warning(
  78. "Loading JSON file ({}) of unknown version".format(
  79. filename
  80. )
  81. )
  82. elif version.split(".")[0] != __version__.split(".")[0]:
  83. logger.warning(
  84. "This JSON file ({}) may be incompatible with "
  85. "current labelme. version in file: {}, "
  86. "current version: {}".format(
  87. filename, version, __version__
  88. )
  89. )
  90. if data["imageData"] is not None:
  91. imageData = base64.b64decode(data["imageData"])
  92. if PY2 and QT4:
  93. imageData = utils.img_data_to_png_data(imageData)
  94. else:
  95. # relative path from label file to relative path from cwd
  96. imagePath = osp.join(osp.dirname(filename), data["imagePath"])
  97. imageData = self.load_image_file(imagePath)
  98. flags = data.get("flags") or {}
  99. imagePath = data["imagePath"]
  100. self._check_image_height_and_width(
  101. base64.b64encode(imageData).decode("utf-8"),
  102. data.get("imageHeight"),
  103. data.get("imageWidth"),
  104. )
  105. shapes = [
  106. dict(
  107. label=s["label"],
  108. points=s["points"],
  109. shape_type=s.get("shape_type", "polygon"),
  110. flags=s.get("flags", {}),
  111. content=s.get("content"),
  112. group_id=s.get("group_id"),
  113. other_data={
  114. k: v for k, v in s.items() if k not in shape_keys
  115. },
  116. )
  117. for s in data["shapes"]
  118. ]
  119. except Exception as e:
  120. raise LabelFileError(e)
  121. otherData = {}
  122. for key, value in data.items():
  123. if key not in keys:
  124. otherData[key] = value
  125. # Only replace data after everything is loaded.
  126. self.flags = flags
  127. self.shapes = shapes
  128. self.imagePath = imagePath
  129. self.imageData = imageData
  130. self.filename = filename
  131. self.otherData = otherData
  132. @staticmethod
  133. def _check_image_height_and_width(imageData, imageHeight, imageWidth):
  134. img_arr = utils.img_b64_to_arr(imageData)
  135. if imageHeight is not None and img_arr.shape[0] != imageHeight:
  136. logger.error(
  137. "imageHeight does not match with imageData or imagePath, "
  138. "so getting imageHeight from actual image."
  139. )
  140. imageHeight = img_arr.shape[0]
  141. if imageWidth is not None and img_arr.shape[1] != imageWidth:
  142. logger.error(
  143. "imageWidth does not match with imageData or imagePath, "
  144. "so getting imageWidth from actual image."
  145. )
  146. imageWidth = img_arr.shape[1]
  147. return imageHeight, imageWidth
  148. def save(
  149. self,
  150. filename,
  151. shapes,
  152. imagePath,
  153. imageHeight,
  154. imageWidth,
  155. imageData=None,
  156. otherData=None,
  157. flags=None,
  158. ):
  159. if imageData is not None:
  160. imageData = base64.b64encode(imageData).decode("utf-8")
  161. imageHeight, imageWidth = self._check_image_height_and_width(
  162. imageData, imageHeight, imageWidth
  163. )
  164. if otherData is None:
  165. otherData = {}
  166. if flags is None:
  167. flags = {}
  168. data = dict(
  169. version=__version__,
  170. flags=flags,
  171. shapes=shapes,
  172. imagePath=imagePath,
  173. imageData=imageData,
  174. imageHeight=imageHeight,
  175. imageWidth=imageWidth,
  176. )
  177. for key, value in otherData.items():
  178. assert key not in data
  179. data[key] = value
  180. try:
  181. with open(filename, "w") as f:
  182. json.dump(data, f, ensure_ascii=False, indent=2)
  183. self.filename = filename
  184. except Exception as e:
  185. raise LabelFileError(e)
  186. @staticmethod
  187. def is_label_file(filename):
  188. return osp.splitext(filename)[1].lower() == LabelFile.suffix