123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147 |
- import base64
- import json
- import os.path
- from . import logger
- from . import PY2
- from . import utils
- from ._version import __version__
class LabelFileError(Exception):
    """Raised when a label file cannot be loaded or saved."""
class LabelFile(object):
    """Read and write a JSON label file (image plus its annotations).

    After a successful ``load`` the instance exposes ``flags``, ``shapes``,
    ``imagePath``, ``imageData``, ``lineColor``, ``fillColor``, ``filename``
    and ``otherData``.
    """

    # File extension recognised by ``isLabelFile``.
    suffix = '.json'

    def __init__(self, filename=None):
        """Create an empty label file, or load ``filename`` when given.

        Raises:
            LabelFileError: if ``filename`` is given and cannot be loaded.
        """
        self.shapes = ()
        self.imagePath = None
        self.imageData = None
        if filename is not None:
            self.load(filename)
        self.filename = filename

    def load(self, filename):
        """Load annotations and image bytes from the JSON file ``filename``.

        When the file embeds ``imageData`` it is base64-decoded; otherwise
        the image is read from ``imagePath``, resolved relative to the
        directory of ``filename``.  Instance attributes are only replaced
        once everything has been read successfully.

        Raises:
            LabelFileError: on any failure while reading or parsing the
                file, including malformed shape entries.
        """
        # Keys handled explicitly; everything else is kept in ``otherData``.
        # 'version' is listed so that a stale version read from disk does
        # not end up in ``otherData`` and later override the version that
        # ``save`` stamps.
        keys = (
            'version',
            'imageData',
            'imagePath',
            'lineColor',
            'fillColor',
            'shapes',  # polygonal annotations
            'flags',   # image level flags
            'imageHeight',
            'imageWidth',
        )
        try:
            with open(filename, 'rb' if PY2 else 'r') as f:
                data = json.load(f)
            if data['imageData'] is not None:
                imageData = base64.b64decode(data['imageData'])
            else:
                # relative path from label file to relative path from cwd
                image_file = os.path.join(os.path.dirname(filename),
                                          data['imagePath'])
                with open(image_file, 'rb') as f:
                    imageData = f.read()
            flags = data.get('flags')
            imagePath = data['imagePath']
            # Only logs a mismatch between the recorded and the actual
            # image size; the corrected values are not stored here.
            self._check_image_height_and_width(
                base64.b64encode(imageData).decode('utf-8'),
                data.get('imageHeight'),
                data.get('imageWidth'),
            )
            lineColor = data['lineColor']
            fillColor = data['fillColor']
            # Built eagerly so that malformed entries fail here, inside the
            # try block, and are reported as LabelFileError.
            shapes = self._parse_shapes(data['shapes'])
        except Exception as e:
            raise LabelFileError(e)

        otherData = {
            key: value for key, value in data.items() if key not in keys
        }

        # Only replace data after everything is loaded.
        self.flags = flags
        self.shapes = shapes
        self.imagePath = imagePath
        self.imageData = imageData
        self.lineColor = lineColor
        self.fillColor = fillColor
        self.filename = filename
        self.otherData = otherData

    @staticmethod
    def _parse_shapes(raw_shapes):
        """Convert raw shape dicts into a list of 5-tuples.

        Each tuple is ``(label, points, line_color, fill_color,
        shape_type)``; ``shape_type`` defaults to ``'polygon'`` when the
        entry does not provide it.

        Raises:
            KeyError: if an entry lacks one of the required keys.
        """
        return [
            (
                s['label'],
                s['points'],
                s['line_color'],
                s['fill_color'],
                s.get('shape_type', 'polygon'),
            )
            for s in raw_shapes
        ]

    @staticmethod
    def _check_image_height_and_width(imageData, imageHeight, imageWidth):
        """Reconcile the recorded image size with the actual image.

        ``imageData`` is the base64 text of the image, decoded through
        ``utils.img_b64_to_arr``.  A recorded dimension that is not ``None``
        and differs from the decoded array's shape is logged as an error and
        replaced by the actual value.

        Returns:
            The ``(imageHeight, imageWidth)`` pair after correction.
        """
        img_arr = utils.img_b64_to_arr(imageData)
        if imageHeight is not None and img_arr.shape[0] != imageHeight:
            logger.error(
                'imageHeight does not match with imageData or imagePath, '
                'so getting imageHeight from actual image.'
            )
            imageHeight = img_arr.shape[0]
        if imageWidth is not None and img_arr.shape[1] != imageWidth:
            logger.error(
                'imageWidth does not match with imageData or imagePath, '
                'so getting imageWidth from actual image.'
            )
            imageWidth = img_arr.shape[1]
        return imageHeight, imageWidth

    def save(
        self,
        filename,
        shapes,
        imagePath,
        imageHeight,
        imageWidth,
        imageData=None,
        lineColor=None,
        fillColor=None,
        otherData=None,
        flags=None,
    ):
        """Write the annotations to ``filename`` as JSON.

        Args:
            filename: destination path of the label file.
            shapes: annotation shapes, passed to ``json.dump`` as given.
            imagePath: image path to record in the file.
            imageHeight: recorded image height; corrected against
                ``imageData`` when that is given.
            imageWidth: recorded image width; corrected likewise.
            imageData: raw image bytes to embed (base64-encoded on write),
                or ``None`` to embed nothing.
            lineColor: value stored under ``lineColor``.
            fillColor: value stored under ``fillColor``.
            otherData: extra top-level entries; these are copied last and
                therefore override the standard keys on a name clash.
            flags: image level flags; defaults to an empty list.

        Raises:
            LabelFileError: if the file cannot be written or the data
                cannot be serialized.
        """
        if imageData is not None:
            imageData = base64.b64encode(imageData).decode('utf-8')
            imageHeight, imageWidth = self._check_image_height_and_width(
                imageData, imageHeight, imageWidth
            )
        if otherData is None:
            otherData = {}
        if flags is None:
            flags = []
        data = dict(
            version=__version__,
            flags=flags,
            shapes=shapes,
            lineColor=lineColor,
            fillColor=fillColor,
            imagePath=imagePath,
            imageData=imageData,
            imageHeight=imageHeight,
            imageWidth=imageWidth,
        )
        for key, value in otherData.items():
            data[key] = value
        try:
            with open(filename, 'wb' if PY2 else 'w') as f:
                json.dump(data, f, ensure_ascii=False, indent=2)
            # Remember the path only once the write has succeeded.
            self.filename = filename
        except Exception as e:
            raise LabelFileError(e)

    @staticmethod
    def isLabelFile(filename):
        """Return True if ``filename`` has the label-file extension.

        The comparison is case-insensitive.
        """
        return os.path.splitext(filename)[1].lower() == LabelFile.suffix
|