123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195 |
- import base64
- import io
- import json
- import os.path as osp
- import PIL.Image
- from labelme import __version__
- from labelme.logger import logger
- from labelme import PY2
- from labelme import QT4
- from labelme import utils
# Setting MAX_IMAGE_PIXELS to None turns off Pillow's pixel-count limit
# (the "decompression bomb" check), so very large images can be opened.
PIL.Image.MAX_IMAGE_PIXELS = None
class LabelFileError(Exception):
    """Raised when a label file cannot be loaded or saved."""
class LabelFile(object):
    """Reader/writer for a JSON label file.

    After a successful :meth:`load` the instance exposes ``flags``,
    ``shapes``, ``imagePath``, ``imageData``, ``filename`` and
    ``otherData``.  Each entry of ``shapes`` is a dict with the keys
    ``label``, ``points``, ``shape_type``, ``flags``, ``group_id`` and
    ``other_data``.
    """

    # File extension (lower case, with the dot) identifying label files.
    suffix = ".json"

    def __init__(self, filename=None):
        """Create an empty label file, or load *filename* when given.

        Raises:
            LabelFileError: if *filename* is given and cannot be loaded.
        """
        self.shapes = []
        self.imagePath = None
        self.imageData = None
        # Define these up front so they exist even when nothing is loaded.
        self.flags = {}
        self.otherData = {}
        if filename is not None:
            self.load(filename)
        self.filename = filename

    @staticmethod
    def load_image_file(filename):
        """Read an image and return it re-encoded as JPEG or PNG bytes.

        The EXIF orientation is applied before re-encoding.  Files ending
        in ``.jpg``/``.jpeg`` are written as JPEG (except on PY2 + QT4,
        where PNG is always used); everything else is written as PNG.

        Returns:
            The encoded image bytes, or ``None`` when the file cannot be
            opened (the failure is logged).
        """
        try:
            image_pil = PIL.Image.open(filename)
        except IOError:
            logger.error("Failed opening image file: {}".format(filename))
            return None

        try:
            # apply orientation to image according to exif
            oriented = utils.apply_exif_orientation(image_pil)

            ext = osp.splitext(filename)[1].lower()
            if ext in (".jpg", ".jpeg") and not (PY2 and QT4):
                image_format = "JPEG"
            else:
                image_format = "PNG"

            with io.BytesIO() as buf:
                oriented.save(buf, format=image_format)
                return buf.getvalue()
        finally:
            # Release the underlying file handle even if encoding fails.
            image_pil.close()

    def load(self, filename):
        """Load a label file and populate this instance from it.

        The instance attributes are only replaced once everything has
        been read successfully.

        Raises:
            LabelFileError: on any failure while reading or parsing the
                file, or when the referenced image cannot be loaded.
        """
        keys = [
            "version",
            "imageData",
            "imagePath",
            "shapes",  # polygonal annotations
            "flags",  # image level flags
            "imageHeight",
            "imageWidth",
        ]
        shape_keys = [
            "label",
            "points",
            "group_id",
            "shape_type",
            "flags",
        ]
        try:
            with open(filename, "rb" if PY2 else "r") as f:
                data = json.load(f)

            version = data.get("version")
            if version is None:
                logger.warning(
                    "Loading JSON file ({}) of unknown version".format(
                        filename
                    )
                )
            elif version.split(".")[0] != __version__.split(".")[0]:
                # Only the major version component is compared.
                logger.warning(
                    "This JSON file ({}) may be incompatible with "
                    "current labelme. version in file: {}, "
                    "current version: {}".format(
                        filename, version, __version__
                    )
                )

            if data["imageData"] is not None:
                imageData = base64.b64decode(data["imageData"])
                if PY2 and QT4:
                    imageData = utils.img_data_to_png_data(imageData)
            else:
                # relative path from label file to relative path from cwd
                imagePath = osp.join(osp.dirname(filename), data["imagePath"])
                imageData = self.load_image_file(imagePath)
                if imageData is None:
                    # Fail with a clear message instead of a TypeError
                    # from base64-encoding None below.
                    raise LabelFileError(
                        "Failed to load image data from: {}".format(
                            imagePath
                        )
                    )

            flags = data.get("flags") or {}
            # The path is stored as written in the file, not joined.
            imagePath = data["imagePath"]
            # Called for its logging side effect; the result is unused.
            self._check_image_height_and_width(
                base64.b64encode(imageData).decode("utf-8"),
                data.get("imageHeight"),
                data.get("imageWidth"),
            )
            shapes = [
                dict(
                    label=s["label"],
                    points=s["points"],
                    shape_type=s.get("shape_type", "polygon"),
                    flags=s.get("flags", {}),
                    group_id=s.get("group_id"),
                    other_data={
                        k: v for k, v in s.items() if k not in shape_keys
                    },
                )
                for s in data["shapes"]
            ]
        except LabelFileError:
            raise
        except Exception as e:
            raise LabelFileError(e)

        # Keep every unrecognised top-level entry so it survives a save.
        otherData = {
            key: value for key, value in data.items() if key not in keys
        }

        # Only replace data after everything is loaded.
        self.flags = flags
        self.shapes = shapes
        self.imagePath = imagePath
        self.imageData = imageData
        self.filename = filename
        self.otherData = otherData

    @staticmethod
    def _check_image_height_and_width(imageData, imageHeight, imageWidth):
        """Reconcile the declared image size with the decoded image.

        Args:
            imageData: base64-encoded image, decoded with
                ``utils.img_b64_to_arr``.
            imageHeight: declared height, or ``None`` to skip the check.
            imageWidth: declared width, or ``None`` to skip the check.

        Returns:
            ``(imageHeight, imageWidth)``, where a declared value that
            disagrees with the decoded image is replaced by the actual
            one (and an error is logged).
        """
        img_arr = utils.img_b64_to_arr(imageData)
        if imageHeight is not None and img_arr.shape[0] != imageHeight:
            logger.error(
                "imageHeight does not match with imageData or imagePath, "
                "so getting imageHeight from actual image."
            )
            imageHeight = img_arr.shape[0]
        if imageWidth is not None and img_arr.shape[1] != imageWidth:
            logger.error(
                "imageWidth does not match with imageData or imagePath, "
                "so getting imageWidth from actual image."
            )
            imageWidth = img_arr.shape[1]
        return imageHeight, imageWidth

    def save(
        self,
        filename,
        shapes,
        imagePath,
        imageHeight,
        imageWidth,
        imageData=None,
        otherData=None,
        flags=None,
    ):
        """Write the annotation to *filename* as JSON.

        Args:
            filename: destination path.
            shapes: value stored under ``"shapes"``.
            imagePath: value stored under ``"imagePath"``.
            imageHeight: declared image height.
            imageWidth: declared image width.
            imageData: raw image bytes, or ``None``.  When given it is
                stored base64-encoded and the declared size is checked
                against it.
            otherData: extra top-level entries; its keys must not
                collide with the standard ones.
            flags: value stored under ``"flags"`` (default ``{}``).

        Raises:
            LabelFileError: if an *otherData* key collides with a
                standard key, or if writing the file fails.
        """
        if imageData is not None:
            imageData = base64.b64encode(imageData).decode("utf-8")
            imageHeight, imageWidth = self._check_image_height_and_width(
                imageData, imageHeight, imageWidth
            )
        if otherData is None:
            otherData = {}
        if flags is None:
            flags = {}
        data = dict(
            version=__version__,
            flags=flags,
            shapes=shapes,
            imagePath=imagePath,
            imageData=imageData,
            imageHeight=imageHeight,
            imageWidth=imageWidth,
        )
        for key, value in otherData.items():
            # An explicit raise, unlike an assert, is not stripped by -O.
            if key in data:
                raise LabelFileError(
                    "otherData key collides with a reserved key: {}".format(
                        key
                    )
                )
            data[key] = value
        try:
            # NOTE(review): text mode uses the platform default encoding,
            # and ensure_ascii=False writes non-ASCII characters as-is;
            # confirm whether UTF-8 should be forced here and in load().
            with open(filename, "wb" if PY2 else "w") as f:
                json.dump(data, f, ensure_ascii=False, indent=2)
            self.filename = filename
        except Exception as e:
            raise LabelFileError(e)

    @staticmethod
    def is_label_file(filename):
        """Return True if *filename* has the label-file extension.

        The comparison is case-insensitive.
        """
        return osp.splitext(filename)[1].lower() == LabelFile.suffix
|