utils.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. import base64
  2. try:
  3. import io
  4. except ImportError:
  5. import io as io
  6. import warnings
  7. import matplotlib.pyplot as plt
  8. import numpy as np
  9. import PIL.Image
  10. import PIL.ImageDraw
  11. def label_colormap(N=256):
  12. def bitget(byteval, idx):
  13. return ((byteval & (1 << idx)) != 0)
  14. cmap = np.zeros((N, 3))
  15. for i in range(0, N):
  16. id = i
  17. r, g, b = 0, 0, 0
  18. for j in range(0, 8):
  19. r = np.bitwise_or(r, (bitget(id, 0) << 7-j))
  20. g = np.bitwise_or(g, (bitget(id, 1) << 7-j))
  21. b = np.bitwise_or(b, (bitget(id, 2) << 7-j))
  22. id = (id >> 3)
  23. cmap[i, 0] = r
  24. cmap[i, 1] = g
  25. cmap[i, 2] = b
  26. cmap = cmap.astype(np.float32) / 255
  27. return cmap
  28. def labelcolormap(N=256):
  29. warnings.warn('labelcolormap is deprecated. Please use label_colormap.')
  30. return label_colormap(N=N)
# Similar to skimage.color.label2rgb.
  32. def label2rgb(lbl, img=None, n_labels=None, alpha=0.3, thresh_suppress=0):
  33. if n_labels is None:
  34. n_labels = len(np.unique(lbl))
  35. cmap = label_colormap(n_labels)
  36. cmap = (cmap * 255).astype(np.uint8)
  37. lbl_viz = cmap[lbl]
  38. lbl_viz[lbl == -1] = (0, 0, 0) # unlabeled
  39. if img is not None:
  40. img_gray = PIL.Image.fromarray(img).convert('LA')
  41. img_gray = np.asarray(img_gray.convert('RGB'))
  42. # img_gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
  43. # img_gray = cv2.cvtColor(img_gray, cv2.COLOR_GRAY2RGB)
  44. lbl_viz = alpha * lbl_viz + (1 - alpha) * img_gray
  45. lbl_viz = lbl_viz.astype(np.uint8)
  46. return lbl_viz
  47. def img_b64_to_array(img_b64):
  48. f = io.BytesIO()
  49. f.write(base64.b64decode(img_b64))
  50. img_arr = np.array(PIL.Image.open(f))
  51. return img_arr
  52. def polygons_to_mask(img_shape, polygons):
  53. mask = np.zeros(img_shape[:2], dtype=np.uint8)
  54. mask = PIL.Image.fromarray(mask)
  55. xy = list(map(tuple, polygons))
  56. PIL.ImageDraw.Draw(mask).polygon(xy=xy, outline=1, fill=1)
  57. mask = np.array(mask, dtype=bool)
  58. return mask
  59. def draw_label(label, img, label_names, colormap=None):
  60. backend_org = plt.rcParams['backend']
  61. plt.switch_backend('agg')
  62. plt.subplots_adjust(left=0, right=1, top=1, bottom=0,
  63. wspace=0, hspace=0)
  64. plt.margins(0, 0)
  65. plt.gca().xaxis.set_major_locator(plt.NullLocator())
  66. plt.gca().yaxis.set_major_locator(plt.NullLocator())
  67. if colormap is None:
  68. colormap = label_colormap(len(label_names))
  69. label_viz = label2rgb(label, img, n_labels=len(label_names))
  70. plt.imshow(label_viz)
  71. plt.axis('off')
  72. plt_handlers = []
  73. plt_titles = []
  74. for label_value, label_name in enumerate(label_names):
  75. fc = colormap[label_value]
  76. p = plt.Rectangle((0, 0), 1, 1, fc=fc)
  77. plt_handlers.append(p)
  78. plt_titles.append(label_name)
  79. plt.legend(plt_handlers, plt_titles, loc='lower right', framealpha=.5)
  80. f = io.BytesIO()
  81. plt.savefig(f, bbox_inches='tight', pad_inches=0)
  82. plt.cla()
  83. plt.close()
  84. plt.switch_backend(backend_org)
  85. out_size = (img.shape[1], img.shape[0])
  86. out = PIL.Image.open(f).resize(out_size, PIL.Image.BILINEAR).convert('RGB')
  87. out = np.asarray(out)
  88. return out
  89. def labelme_shapes_to_label(img_shape, shapes):
  90. label_name_to_val = {'background': 0}
  91. lbl = np.zeros(img_shape[:2], dtype=np.int32)
  92. for shape in shapes:
  93. polygons = shape['points']
  94. label_name = shape['label']
  95. if label_name in label_name_to_val:
  96. label_value = label_name_to_val[label_name]
  97. else:
  98. label_value = len(label_name_to_val)
  99. label_name_to_val[label_name] = label_value
  100. mask = polygons_to_mask(img_shape[:2], polygons)
  101. lbl[mask] = label_value
  102. lbl_names = [None] * (max(label_name_to_val.values()) + 1)
  103. for label_name, label_value in label_name_to_val.items():
  104. lbl_names[label_value] = label_name
  105. return lbl, lbl_names