123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175 |
- import base64
- import io
- import warnings
- import numpy as np
- import PIL.Image
- import PIL.ImageDraw
def label_colormap(N=256):
    """Build a distinct RGB color for each label value ``0 .. N-1``.

    The bits of each label index are dealt out round-robin to the red,
    green and blue channels, filling each channel from its most significant
    bit downwards over 8 rounds (so the low 24 bits of the index are used).
    For example label 1 -> (128, 0, 0), label 2 -> (0, 128, 0),
    label 8 -> (64, 0, 0), all scaled to [0, 1].

    Args:
        N: Number of colors (rows) to generate.

    Returns:
        ``float32`` array of shape ``(N, 3)`` with values in ``[0, 1]``.
    """
    index = np.arange(N)
    channels = np.zeros((N, 3), dtype=np.int64)
    for round_ in range(8):
        target_bit = 7 - round_
        for c in range(3):
            # Bit ``c`` of the remaining index lands in bit ``target_bit``
            # of channel ``c``.
            channels[:, c] |= ((index >> c) & 1) << target_bit
        # Consume the three bits just distributed.
        index = index >> 3
    return channels.astype(np.float32) / 255
def labelcolormap(N=256):
    """Deprecated alias of ``label_colormap``.

    Emits a warning, then returns ``label_colormap(N=N)`` unchanged.
    """
    # stacklevel=2 attributes the warning to the caller's line (the code
    # that needs migrating) instead of to this wrapper.
    warnings.warn('labelcolormap is deprecated. Please use label_colormap.',
                  stacklevel=2)
    return label_colormap(N=N)
# Similar to skimage.color.label2rgb.
def label2rgb(lbl, img=None, n_labels=None, alpha=0.3, thresh_suppress=0):
    """Colorize a label map, optionally blending it over a grayscale image.

    Args:
        lbl: Integer label array. The value -1 marks unlabeled pixels, which
            are drawn black.
        img: Optional image array (anything ``PIL.Image.fromarray`` accepts).
            When given, the label colors are alpha-blended over a grayscale
            version of it. Presumably the same height/width as ``lbl`` —
            the blend simply broadcasts the two arrays.
        n_labels: Number of colormap entries to build. Defaults to
            ``max(lbl) + 1`` so every label value present can be indexed,
            even when the values are not contiguous.
        alpha: Weight of the label colors in the blend; the image gets
            ``1 - alpha``.
        thresh_suppress: Unused; kept for backward compatibility.

    Returns:
        ``uint8`` RGB array of shape ``lbl.shape + (3,)`` (broadcast with
        the grayscale image when ``img`` is given).
    """
    lbl = np.asarray(lbl)
    if n_labels is None:
        # Size the colormap by the largest label, not by the number of
        # distinct labels: with non-contiguous labels (e.g. {0, 5}) the
        # distinct count is too small and ``cmap[lbl]`` raises IndexError.
        # Each colormap row depends only on its index, so the colors are
        # unchanged. ``initial=0`` keeps at least one entry for empty or
        # all -1 input.
        n_labels = int(lbl.max(initial=0)) + 1
    cmap = label_colormap(n_labels)
    cmap = (cmap * 255).astype(np.uint8)

    lbl_viz = cmap[lbl]
    lbl_viz[lbl == -1] = (0, 0, 0)  # unlabeled

    if img is not None:
        # 'LA' -> 'RGB' yields a 3-channel grayscale copy of the image.
        img_gray = PIL.Image.fromarray(img).convert('LA')
        img_gray = np.asarray(img_gray.convert('RGB'))
        lbl_viz = alpha * lbl_viz + (1 - alpha) * img_gray
        lbl_viz = lbl_viz.astype(np.uint8)

    return lbl_viz
def img_b64_to_array(img_b64):
    """Deprecated alias of ``img_b64_to_arr``.

    Emits a warning, then returns ``img_b64_to_arr(img_b64)`` unchanged.
    """
    # The message previously misspelled this function's name
    # ("img_ba64_to_array"). stacklevel=2 points the warning at the caller.
    warnings.warn('img_b64_to_array is deprecated. '
                  'Please use img_b64_to_arr.', stacklevel=2)
    return img_b64_to_arr(img_b64)
def img_b64_to_arr(img_b64):
    """Decode a base64-encoded image file into a NumPy array.

    Args:
        img_b64: Base64 text/bytes of an encoded image (any format PIL can
            open).

    Returns:
        The decoded pixels as produced by ``np.array(PIL.Image.open(...))``.
    """
    raw_bytes = base64.b64decode(img_b64)
    with io.BytesIO(raw_bytes) as buffer:
        # np.array forces PIL to load the pixels while the buffer is open.
        return np.array(PIL.Image.open(buffer))
def img_arr_to_b64(img_arr):
    """Encode an image array as base64 PNG data.

    Args:
        img_arr: Array accepted by ``PIL.Image.fromarray``.

    Returns:
        ``bytes``: the PNG file, base64-encoded MIME-style (a newline after
        every 76 output characters plus a trailing newline).
    """
    img_pil = PIL.Image.fromarray(img_arr)
    with io.BytesIO() as f:
        img_pil.save(f, format='PNG')
        img_bin = f.getvalue()
    # base64.encodestring was removed in Python 3.9; encodebytes is its
    # identical-output replacement. Fall back only where it does not exist
    # (Python 2).
    encode = getattr(base64, 'encodebytes', None) or base64.encodestring
    return encode(img_bin)
def polygons_to_mask(img_shape, polygons):
    """Rasterize one polygon into a boolean mask.

    Args:
        img_shape: Shape whose first two entries give the mask size.
        polygons: Sequence of vertices, each convertible to a tuple
            (e.g. ``[x, y]`` pairs), forming a single polygon.

    Returns:
        ``bool`` array of shape ``img_shape[:2]``, True inside the polygon
        and on its outline.
    """
    canvas = PIL.Image.fromarray(np.zeros(img_shape[:2], dtype=np.uint8))
    vertices = [tuple(point) for point in polygons]
    drawer = PIL.ImageDraw.Draw(canvas)
    drawer.polygon(xy=vertices, outline=1, fill=1)
    return np.array(canvas, dtype=bool)
def draw_label(label, img, label_names, colormap=None):
    """Render a label map over an image, with a legend of label names.

    Args:
        label: Integer label array; value ``i`` corresponds to
            ``label_names[i]``.
        img: Image array the labels are blended over (see ``label2rgb``).
        label_names: Sequence of names indexed by label value. Names that
            start with ``'_'`` are left out of the legend, as are labels not
            present in ``label``.
        colormap: Optional ``(n, 3)`` colors used for the legend swatches.
            Defaults to ``label_colormap(len(label_names))``.

    Returns:
        RGB array with the same height and width as ``img``.
    """
    import matplotlib.pyplot as plt

    backend_org = plt.rcParams['backend']
    plt.switch_backend('agg')
    # Restore the caller's backend and release the figure even if drawing
    # or saving fails; otherwise the process is left on the 'agg' backend.
    try:
        plt.subplots_adjust(left=0, right=1, top=1, bottom=0,
                            wspace=0, hspace=0)
        plt.margins(0, 0)
        plt.gca().xaxis.set_major_locator(plt.NullLocator())
        plt.gca().yaxis.set_major_locator(plt.NullLocator())

        if colormap is None:
            colormap = label_colormap(len(label_names))

        # NOTE(review): the image colors come from label2rgb's own colormap,
        # so a custom ``colormap`` only affects the legend swatches.
        label_viz = label2rgb(label, img, n_labels=len(label_names))
        plt.imshow(label_viz)
        plt.axis('off')

        # Compute the present label values once instead of scanning the
        # whole array for every name.
        present = set(np.unique(label).tolist())
        plt_handlers = []
        plt_titles = []
        for label_value, label_name in enumerate(label_names):
            if label_value not in present:
                continue
            if label_name.startswith('_'):
                continue
            fc = colormap[label_value]
            p = plt.Rectangle((0, 0), 1, 1, fc=fc)
            plt_handlers.append(p)
            plt_titles.append(label_name)
        plt.legend(plt_handlers, plt_titles, loc='lower right', framealpha=.5)

        f = io.BytesIO()
        plt.savefig(f, bbox_inches='tight', pad_inches=0)
    finally:
        plt.cla()
        plt.close()
        plt.switch_backend(backend_org)

    # The saved figure's size depends on matplotlib's layout; resize it back
    # to the input image's (width, height).
    out_size = (img.shape[1], img.shape[0])
    out = PIL.Image.open(f).resize(out_size, PIL.Image.BILINEAR).convert('RGB')
    return np.asarray(out)
def shapes_to_label(img_shape, shapes, label_name_to_value, type='class'):
    """Rasterize labelme shapes into a class label map (and instance map).

    Args:
        img_shape: Image shape; only the first two entries are used.
        shapes: Iterable of dicts with a ``'points'`` entry (polygon
            vertices) and a ``'label'`` entry (name string).
        label_name_to_value: Mapping from class name to integer class id.
        type: ``'class'`` to return only the class map, or ``'instance'`` to
            also return an instance map. In instance mode the class name is
            the label text before the first ``'-'`` (``'person-2'`` ->
            ``'person'``), and each distinct full label gets its own instance
            id starting at 1 (0 is reserved for ``'__background__'``).

    Returns:
        ``cls`` (int32 array of shape ``img_shape[:2]``), or ``(cls, ins)``
        in instance mode. Pixels covered by no shape are 0; where shapes
        overlap, later shapes overwrite earlier ones.

    Raises:
        ValueError: if ``type`` is neither ``'class'`` nor ``'instance'``.
        KeyError: if a shape's class name is missing from
            ``label_name_to_value``.
    """
    # Explicit check rather than ``assert``, which is stripped under -O.
    if type not in ('class', 'instance'):
        raise ValueError(
            "type must be 'class' or 'instance', got {!r}".format(type))

    cls = np.zeros(img_shape[:2], dtype=np.int32)
    if type == 'instance':
        ins = np.zeros(img_shape[:2], dtype=np.int32)
        instance_names = ['__background__']

    for shape in shapes:
        polygons = shape['points']
        label = shape['label']
        if type == 'class':
            cls_name = label
        else:
            cls_name = label.split('-')[0]
            if label not in instance_names:
                instance_names.append(label)
            # Look the id up: ``len(instance_names) - 1`` gave a repeated
            # label the id of whichever instance was added most recently.
            ins_id = instance_names.index(label)
        cls_id = label_name_to_value[cls_name]
        mask = polygons_to_mask(img_shape[:2], polygons)
        cls[mask] = cls_id
        if type == 'instance':
            ins[mask] = ins_id

    if type == 'instance':
        return cls, ins
    return cls
def labelme_shapes_to_label(img_shape, shapes):
    """Deprecated: derive a name->value mapping and rasterize the shapes.

    Label values are assigned 0, 1, 2, ... in order of first appearance in
    ``shapes``, then ``shapes_to_label`` builds the class map.

    Returns:
        ``(lbl, label_name_to_value)``.
    """
    # stacklevel=2 points the warning at the caller that needs migrating.
    warnings.warn('labelme_shapes_to_label is deprecated, so please use '
                  'shapes_to_label.', stacklevel=2)
    # NOTE(review): the first label seen receives value 0, the same value
    # shapes_to_label leaves on uncovered pixels.
    label_name_to_value = {}
    for shape in shapes:
        # len() is evaluated before insertion, so a new name gets the next
        # consecutive value and a known name keeps the one it already has.
        label_name_to_value.setdefault(shape['label'],
                                       len(label_name_to_value))
    lbl = shapes_to_label(img_shape, shapes, label_name_to_value)
    return lbl, label_name_to_value
|