#!/usr/bin/env python
"""Display the image stored in an annotation JSON file next to its label overlay.

The JSON file is expected to provide an ``imageData`` entry (base64-encoded
image bytes) and a ``shapes`` entry (a list of objects with ``label`` and
``points`` keys), because those are the keys this script reads.
"""

import argparse
import base64
import io
import json

import matplotlib.pyplot as plt
import numpy as np
import PIL.Image
import PIL.ImageDraw
from skimage.color import label2rgb


def labelcolormap(N=256):
    """Build a colormap with one RGB row per label index.

    For every index the three lowest bits of a running code are consumed
    eight times: bit 0 feeds the red channel, bit 1 the green channel and
    bit 2 the blue channel, filling each channel from its most significant
    bit downwards. The code is shifted right by three bits after each step.

    Args:
        N: Number of colormap rows to generate.

    Returns:
        A ``numpy.ndarray`` of shape ``(N, 3)`` and dtype ``float32`` whose
        values are the 8-bit channel intensities divided by 255.
    """
    cmap = np.zeros((N, 3))
    for index in range(N):
        code = index
        red, green, blue = 0, 0, 0
        for step in range(8):
            shift = 7 - step
            red |= (code & 1) << shift
            green |= ((code >> 1) & 1) << shift
            blue |= ((code >> 2) & 1) << shift
            code >>= 3
        cmap[index] = (red, green, blue)
    return cmap.astype(np.float32) / 255


def main():
    """Parse the command line, rasterize the shapes and show the result.

    Takes a single positional argument, the path of the JSON file. The
    decoded image is plotted on the left and the label overlay, together
    with a legend of label names, on the right.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('json_file')
    args = parser.parse_args()

    # Close the file deterministically instead of leaking the handle.
    with open(args.json_file) as json_fp:
        data = json.load(json_fp)

    # base64 string -> numpy.ndarray. Constructing the buffer from the
    # decoded bytes leaves its position at 0, ready for PIL to read.
    image_buffer = io.BytesIO(base64.b64decode(data['imageData']))
    img = np.array(PIL.Image.open(image_buffer))

    target_names = {'background': 0}
    label = np.zeros(img.shape[:2], dtype=np.int32)
    for shape in data['shapes']:
        # Label names receive consecutive integer values in order of first
        # appearance; 0 stays reserved for 'background'.
        label_name = shape['label']
        label_value = target_names.get(label_name)
        if label_value is None:
            label_value = len(target_names)
            target_names[label_name] = label_value

        # polygon -> mask
        mask = PIL.Image.fromarray(np.zeros(img.shape[:2], dtype=np.uint8))
        # A concrete list is required: on Python 3 ``map`` is a lazy iterator.
        xy = [tuple(point) for point in shape['points']]
        PIL.ImageDraw.Draw(mask).polygon(xy=xy, outline=1, fill=1)
        mask = np.array(mask)

        # Shapes are painted in file order, so a later shape overwrites an
        # earlier one wherever their masks overlap.
        label[mask == 1] = label_value

    cmap = labelcolormap(len(target_names))
    label_viz = label2rgb(label, img, colors=cmap)
    label_viz[label == target_names['background']] = 0

    plt.subplots_adjust(left=0, right=1, top=1, bottom=0,
                        wspace=0, hspace=0)
    plt.margins(0, 0)
    plt.gca().xaxis.set_major_locator(plt.NullLocator())
    plt.gca().yaxis.set_major_locator(plt.NullLocator())

    plt.subplot(121)
    plt.axis('off')
    plt.imshow(img)

    plt.subplot(122)
    plt.axis('off')
    plt.imshow(label_viz)

    # plot target_names legend
    plt_handlers = []
    plt_titles = []
    for l_name, l_value in target_names.items():
        patch = plt.Rectangle((0, 0), 1, 1, fc=cmap[l_value])
        plt_handlers.append(patch)
        plt_titles.append(l_name)
    plt.legend(plt_handlers, plt_titles, loc='lower right', framealpha=.5)

    plt.show()


if __name__ == '__main__':
    main()