utils.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. import base64
  2. try:
  3. import io
  4. except ImportError:
  5. import io as io
  6. import matplotlib.pyplot as plt
  7. import numpy as np
  8. import PIL.Image
  9. import PIL.ImageDraw
  10. import scipy.misc
  11. import skimage.color
  12. def labelcolormap(N=256):
  13. def bitget(byteval, idx):
  14. return ((byteval & (1 << idx)) != 0)
  15. cmap = np.zeros((N, 3))
  16. for i in range(0, N):
  17. id = i
  18. r, g, b = 0, 0, 0
  19. for j in range(0, 8):
  20. r = np.bitwise_or(r, (bitget(id, 0) << 7-j))
  21. g = np.bitwise_or(g, (bitget(id, 1) << 7-j))
  22. b = np.bitwise_or(b, (bitget(id, 2) << 7-j))
  23. id = (id >> 3)
  24. cmap[i, 0] = r
  25. cmap[i, 1] = g
  26. cmap[i, 2] = b
  27. cmap = cmap.astype(np.float32) / 255
  28. return cmap
  29. def img_b64_to_array(img_b64):
  30. f = io.BytesIO()
  31. f.write(base64.b64decode(img_b64))
  32. img_arr = np.array(PIL.Image.open(f))
  33. return img_arr
  34. def polygons_to_mask(img_shape, polygons):
  35. mask = np.zeros(img_shape[:2], dtype=np.uint8)
  36. mask = PIL.Image.fromarray(mask)
  37. xy = list(map(tuple, polygons))
  38. PIL.ImageDraw.Draw(mask).polygon(xy=xy, outline=1, fill=1)
  39. mask = np.array(mask, dtype=bool)
  40. return mask
  41. def draw_label(label, img, label_names, colormap=None):
  42. plt.subplots_adjust(left=0, right=1, top=1, bottom=0,
  43. wspace=0, hspace=0)
  44. plt.margins(0, 0)
  45. plt.gca().xaxis.set_major_locator(plt.NullLocator())
  46. plt.gca().yaxis.set_major_locator(plt.NullLocator())
  47. if colormap is None:
  48. colormap = labelcolormap(len(label_names))
  49. label_viz = skimage.color.label2rgb(
  50. label, img, colors=colormap[1:], bg_label=0, bg_color=colormap[0])
  51. plt.imshow(label_viz)
  52. plt.axis('off')
  53. plt_handlers = []
  54. plt_titles = []
  55. for label_value, label_name in enumerate(label_names):
  56. fc = colormap[label_value]
  57. p = plt.Rectangle((0, 0), 1, 1, fc=fc)
  58. plt_handlers.append(p)
  59. plt_titles.append(label_name)
  60. plt.legend(plt_handlers, plt_titles, loc='lower right', framealpha=.5)
  61. f = io.BytesIO()
  62. plt.savefig(f, bbox_inches='tight', pad_inches=0)
  63. plt.cla()
  64. plt.close()
  65. out = np.array(PIL.Image.open(f))[:, :, :3]
  66. out = scipy.misc.imresize(out, img.shape[:2])
  67. return out
  68. def labelme_shapes_to_label(img_shape, shapes):
  69. label_name_to_val = {'background': 0}
  70. lbl = np.zeros(img_shape[:2], dtype=np.int32)
  71. for shape in sorted(shapes, key=lambda x: x['label']):
  72. polygons = shape['points']
  73. label_name = shape['label']
  74. if label_name in label_name_to_val:
  75. label_value = label_name_to_val[label_name]
  76. else:
  77. label_value = len(label_name_to_val)
  78. label_name_to_val[label_name] = label_value
  79. mask = polygons_to_mask(img_shape[:2], polygons)
  80. lbl[mask] = label_value
  81. lbl_names = [None] * (max(label_name_to_val.values()) + 1)
  82. for label_name, label_value in label_name_to_val.items():
  83. lbl_names[label_value] = label_name
  84. return lbl, lbl_names