label_file.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. import base64
  2. import contextlib
  3. import io
  4. import json
  5. import os.path as osp
  6. import PIL.Image
  7. from labelme import PY2
  8. from labelme import QT4
  9. from labelme import __version__
  10. from labelme import utils
  11. from labelme.logger import logger
# Disable Pillow's decompression-bomb size limit so that very large images can
# be opened without a DecompressionBombWarning/Error.
# NOTE(review): this also removes Pillow's protection against maliciously
# crafted oversized images; confirm this trade-off is intended.
PIL.Image.MAX_IMAGE_PIXELS = None
  13. @contextlib.contextmanager
  14. def open(name, mode):
  15. assert mode in ["r", "w"]
  16. if PY2:
  17. mode += "b"
  18. encoding = None
  19. else:
  20. encoding = "utf-8"
  21. yield io.open(name, mode, encoding=encoding)
  22. return
  23. class LabelFileError(Exception):
  24. pass
  25. class LabelFile(object):
  26. suffix = ".json"
  27. def __init__(self, filename=None):
  28. self.shapes = []
  29. self.imagePath = None
  30. self.imageData = None
  31. if filename is not None:
  32. self.load(filename)
  33. self.filename = filename
  34. @staticmethod
  35. def load_image_file(filename):
  36. try:
  37. image_pil = PIL.Image.open(filename)
  38. except IOError:
  39. logger.error("Failed opening image file: {}".format(filename))
  40. return
  41. # apply orientation to image according to exif
  42. image_pil = utils.apply_exif_orientation(image_pil)
  43. with io.BytesIO() as f:
  44. ext = osp.splitext(filename)[1].lower()
  45. if PY2 and QT4:
  46. format = "PNG"
  47. elif ext in [".jpg", ".jpeg"]:
  48. format = "JPEG"
  49. else:
  50. format = "PNG"
  51. image_pil.save(f, format=format)
  52. f.seek(0)
  53. return f.read()
  54. def load(self, filename):
  55. keys = [
  56. "version",
  57. "imageData",
  58. "imagePath",
  59. "shapes", # polygonal annotations
  60. "flags", # image level flags
  61. "imageHeight",
  62. "imageWidth",
  63. ]
  64. shape_keys = [
  65. "label",
  66. "points",
  67. "group_id",
  68. "shape_type",
  69. "flags",
  70. "description",
  71. "mask",
  72. ]
  73. try:
  74. with open(filename, "r") as f:
  75. data = json.load(f)
  76. if data["imageData"] is not None:
  77. imageData = base64.b64decode(data["imageData"])
  78. if PY2 and QT4:
  79. imageData = utils.img_data_to_png_data(imageData)
  80. else:
  81. # relative path from label file to relative path from cwd
  82. imagePath = osp.join(osp.dirname(filename), data["imagePath"])
  83. imageData = self.load_image_file(imagePath)
  84. flags = data.get("flags") or {}
  85. imagePath = data["imagePath"]
  86. self._check_image_height_and_width(
  87. base64.b64encode(imageData).decode("utf-8"),
  88. data.get("imageHeight"),
  89. data.get("imageWidth"),
  90. )
  91. shapes = [
  92. dict(
  93. label=s["label"],
  94. points=s["points"],
  95. shape_type=s.get("shape_type", "polygon"),
  96. flags=s.get("flags", {}),
  97. description=s.get("description"),
  98. group_id=s.get("group_id"),
  99. mask=utils.img_b64_to_arr(s["mask"]).astype(bool)
  100. if s.get("mask")
  101. else None,
  102. other_data={k: v for k, v in s.items() if k not in shape_keys},
  103. )
  104. for s in data["shapes"]
  105. ]
  106. except Exception as e:
  107. raise LabelFileError(e)
  108. otherData = {}
  109. for key, value in data.items():
  110. if key not in keys:
  111. otherData[key] = value
  112. # Only replace data after everything is loaded.
  113. self.flags = flags
  114. self.shapes = shapes
  115. self.imagePath = imagePath
  116. self.imageData = imageData
  117. self.filename = filename
  118. self.otherData = otherData
  119. @staticmethod
  120. def _check_image_height_and_width(imageData, imageHeight, imageWidth):
  121. img_arr = utils.img_b64_to_arr(imageData)
  122. if imageHeight is not None and img_arr.shape[0] != imageHeight:
  123. logger.error(
  124. "imageHeight does not match with imageData or imagePath, "
  125. "so getting imageHeight from actual image."
  126. )
  127. imageHeight = img_arr.shape[0]
  128. if imageWidth is not None and img_arr.shape[1] != imageWidth:
  129. logger.error(
  130. "imageWidth does not match with imageData or imagePath, "
  131. "so getting imageWidth from actual image."
  132. )
  133. imageWidth = img_arr.shape[1]
  134. return imageHeight, imageWidth
  135. def save(
  136. self,
  137. filename,
  138. shapes,
  139. imagePath,
  140. imageHeight,
  141. imageWidth,
  142. imageData=None,
  143. otherData=None,
  144. flags=None,
  145. ):
  146. if imageData is not None:
  147. imageData = base64.b64encode(imageData).decode("utf-8")
  148. imageHeight, imageWidth = self._check_image_height_and_width(
  149. imageData, imageHeight, imageWidth
  150. )
  151. if otherData is None:
  152. otherData = {}
  153. if flags is None:
  154. flags = {}
  155. data = dict(
  156. version=__version__,
  157. flags=flags,
  158. shapes=shapes,
  159. imagePath=imagePath,
  160. imageData=imageData,
  161. imageHeight=imageHeight,
  162. imageWidth=imageWidth,
  163. )
  164. for key, value in otherData.items():
  165. assert key not in data
  166. data[key] = value
  167. try:
  168. with open(filename, "w") as f:
  169. json.dump(data, f, ensure_ascii=False, indent=2)
  170. self.filename = filename
  171. except Exception as e:
  172. raise LabelFileError(e)
  173. @staticmethod
  174. def is_label_file(filename):
  175. return osp.splitext(filename)[1].lower() == LabelFile.suffix