utils.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199
  1. import base64
  2. import io
  3. import os.path as osp
  4. import warnings
  5. import numpy as np
  6. import PIL.Image
  7. import PIL.ImageDraw
  8. from labelme import logger
  9. def label_colormap(N=256):
  10. def bitget(byteval, idx):
  11. return ((byteval & (1 << idx)) != 0)
  12. cmap = np.zeros((N, 3))
  13. for i in range(0, N):
  14. id = i
  15. r, g, b = 0, 0, 0
  16. for j in range(0, 8):
  17. r = np.bitwise_or(r, (bitget(id, 0) << 7 - j))
  18. g = np.bitwise_or(g, (bitget(id, 1) << 7 - j))
  19. b = np.bitwise_or(b, (bitget(id, 2) << 7 - j))
  20. id = (id >> 3)
  21. cmap[i, 0] = r
  22. cmap[i, 1] = g
  23. cmap[i, 2] = b
  24. cmap = cmap.astype(np.float32) / 255
  25. return cmap
  26. def labelcolormap(N=256):
  27. warnings.warn('labelcolormap is deprecated. Please use label_colormap.')
  28. return label_colormap(N=N)
  29. # similar function as skimage.color.label2rgb
  30. def label2rgb(lbl, img=None, n_labels=None, alpha=0.3, thresh_suppress=0):
  31. if n_labels is None:
  32. n_labels = len(np.unique(lbl))
  33. cmap = label_colormap(n_labels)
  34. cmap = (cmap * 255).astype(np.uint8)
  35. lbl_viz = cmap[lbl]
  36. lbl_viz[lbl == -1] = (0, 0, 0) # unlabeled
  37. if img is not None:
  38. img_gray = PIL.Image.fromarray(img).convert('LA')
  39. img_gray = np.asarray(img_gray.convert('RGB'))
  40. # img_gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
  41. # img_gray = cv2.cvtColor(img_gray, cv2.COLOR_GRAY2RGB)
  42. lbl_viz = alpha * lbl_viz + (1 - alpha) * img_gray
  43. lbl_viz = lbl_viz.astype(np.uint8)
  44. return lbl_viz
  45. def img_b64_to_array(img_b64):
  46. warnings.warn('img_ba64_to_array is deprecated. '
  47. 'Please use img_b64_to_arr.')
  48. return img_b64_to_arr(img_b64)
  49. def img_b64_to_arr(img_b64):
  50. f = io.BytesIO()
  51. f.write(base64.b64decode(img_b64))
  52. img_arr = np.array(PIL.Image.open(f))
  53. return img_arr
  54. def img_arr_to_b64(img_arr):
  55. img_pil = PIL.Image.fromarray(img_arr)
  56. f = io.BytesIO()
  57. img_pil.save(f, format='PNG')
  58. img_bin = f.getvalue()
  59. img_b64 = base64.encodestring(img_bin)
  60. return img_b64
  61. def polygons_to_mask(img_shape, polygons):
  62. mask = np.zeros(img_shape[:2], dtype=np.uint8)
  63. mask = PIL.Image.fromarray(mask)
  64. xy = list(map(tuple, polygons))
  65. PIL.ImageDraw.Draw(mask).polygon(xy=xy, outline=1, fill=1)
  66. mask = np.array(mask, dtype=bool)
  67. return mask
  68. def draw_label(label, img=None, label_names=None, colormap=None):
  69. import matplotlib.pyplot as plt
  70. backend_org = plt.rcParams['backend']
  71. plt.switch_backend('agg')
  72. plt.subplots_adjust(left=0, right=1, top=1, bottom=0,
  73. wspace=0, hspace=0)
  74. plt.margins(0, 0)
  75. plt.gca().xaxis.set_major_locator(plt.NullLocator())
  76. plt.gca().yaxis.set_major_locator(plt.NullLocator())
  77. if label_names is None:
  78. label_names = [str(l) for l in range(label.max() + 1)]
  79. if colormap is None:
  80. colormap = label_colormap(len(label_names))
  81. label_viz = label2rgb(label, img, n_labels=len(label_names))
  82. plt.imshow(label_viz)
  83. plt.axis('off')
  84. plt_handlers = []
  85. plt_titles = []
  86. for label_value, label_name in enumerate(label_names):
  87. if label_value not in label:
  88. continue
  89. if label_name.startswith('_'):
  90. continue
  91. fc = colormap[label_value]
  92. p = plt.Rectangle((0, 0), 1, 1, fc=fc)
  93. plt_handlers.append(p)
  94. plt_titles.append('{value}: {name}'
  95. .format(value=label_value, name=label_name))
  96. plt.legend(plt_handlers, plt_titles, loc='lower right', framealpha=.5)
  97. f = io.BytesIO()
  98. plt.savefig(f, bbox_inches='tight', pad_inches=0)
  99. plt.cla()
  100. plt.close()
  101. plt.switch_backend(backend_org)
  102. out_size = (label_viz.shape[1], label_viz.shape[0])
  103. out = PIL.Image.open(f).resize(out_size, PIL.Image.BILINEAR).convert('RGB')
  104. out = np.asarray(out)
  105. return out
  106. def shapes_to_label(img_shape, shapes, label_name_to_value, type='class'):
  107. assert type in ['class', 'instance']
  108. cls = np.zeros(img_shape[:2], dtype=np.int32)
  109. if type == 'instance':
  110. ins = np.zeros(img_shape[:2], dtype=np.int32)
  111. instance_names = ['_background_']
  112. for shape in shapes:
  113. polygons = shape['points']
  114. label = shape['label']
  115. if type == 'class':
  116. cls_name = label
  117. elif type == 'instance':
  118. cls_name = label.split('-')[0]
  119. if label not in instance_names:
  120. instance_names.append(label)
  121. ins_id = len(instance_names) - 1
  122. cls_id = label_name_to_value[cls_name]
  123. mask = polygons_to_mask(img_shape[:2], polygons)
  124. cls[mask] = cls_id
  125. if type == 'instance':
  126. ins[mask] = ins_id
  127. if type == 'instance':
  128. return cls, ins
  129. return cls
  130. def labelme_shapes_to_label(img_shape, shapes):
  131. warnings.warn('labelme_shapes_to_label is deprecated, so please use '
  132. 'shapes_to_label.')
  133. label_name_to_value = {'_background_': 0}
  134. for shape in shapes:
  135. label_name = shape['label']
  136. if label_name in label_name_to_value:
  137. label_value = label_name_to_value[label_name]
  138. else:
  139. label_value = len(label_name_to_value)
  140. label_name_to_value[label_name] = label_value
  141. lbl = shapes_to_label(img_shape, shapes, label_name_to_value)
  142. return lbl, label_name_to_value
  143. def lblsave(filename, lbl):
  144. if osp.splitext(filename)[1] != '.png':
  145. filename += '.png'
  146. # Assume label ranses [-1, 254] for int32,
  147. # and [0, 255] for uint8 as VOC.
  148. if lbl.min() >= -1 and lbl.max() < 255:
  149. lbl_pil = PIL.Image.fromarray(lbl.astype(np.uint8), mode='P')
  150. colormap = label_colormap(255)
  151. lbl_pil.putpalette((colormap * 255).astype(np.uint8).flatten())
  152. lbl_pil.save(filename)
  153. else:
  154. logger.warn(
  155. '[%s] Cannot save the pixel-wise class label as PNG, '
  156. 'so please use the npy file.' % filename
  157. )