utils.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. import base64
  2. import cStringIO as StringIO
  3. import matplotlib.pyplot as plt
  4. import numpy as np
  5. import PIL.Image
  6. import PIL.ImageDraw
  7. import scipy.misc
  8. import skimage.color
  9. def labelcolormap(N=256):
  10. def bitget(byteval, idx):
  11. return ((byteval & (1 << idx)) != 0)
  12. cmap = np.zeros((N, 3))
  13. for i in xrange(0, N):
  14. id = i
  15. r, g, b = 0, 0, 0
  16. for j in xrange(0, 8):
  17. r = np.bitwise_or(r, (bitget(id, 0) << 7-j))
  18. g = np.bitwise_or(g, (bitget(id, 1) << 7-j))
  19. b = np.bitwise_or(b, (bitget(id, 2) << 7-j))
  20. id = (id >> 3)
  21. cmap[i, 0] = r
  22. cmap[i, 1] = g
  23. cmap[i, 2] = b
  24. cmap = cmap.astype(np.float32) / 255
  25. return cmap
  26. def img_b64_to_array(img_b64):
  27. f = StringIO.StringIO()
  28. f.write(base64.b64decode(img_b64))
  29. img_arr = np.array(PIL.Image.open(f))
  30. return img_arr
  31. def polygons_to_mask(img_shape, polygons):
  32. mask = np.zeros(img_shape[:2], dtype=np.uint8)
  33. mask = PIL.Image.fromarray(mask)
  34. xy = map(tuple, polygons)
  35. PIL.ImageDraw.Draw(mask).polygon(xy=xy, outline=1, fill=1)
  36. mask = np.array(mask, dtype=bool)
  37. return mask
  38. def draw_label(label, img, label_names, colormap=None):
  39. plt.subplots_adjust(left=0, right=1, top=1, bottom=0,
  40. wspace=0, hspace=0)
  41. plt.margins(0, 0)
  42. plt.gca().xaxis.set_major_locator(plt.NullLocator())
  43. plt.gca().yaxis.set_major_locator(plt.NullLocator())
  44. if colormap is None:
  45. colormap = labelcolormap(len(label_names))
  46. label_viz = skimage.color.label2rgb(
  47. label, img, colors=colormap[1:], bg_label=0, bg_color=colormap[0])
  48. plt.imshow(label_viz)
  49. plt.axis('off')
  50. plt_handlers = []
  51. plt_titles = []
  52. for label_value, label_name in enumerate(label_names):
  53. fc = colormap[label_value]
  54. p = plt.Rectangle((0, 0), 1, 1, fc=fc)
  55. plt_handlers.append(p)
  56. plt_titles.append(label_name)
  57. plt.legend(plt_handlers, plt_titles, loc='lower right', framealpha=.5)
  58. f = StringIO.StringIO()
  59. plt.savefig(f, bbox_inches='tight', pad_inches=0)
  60. plt.cla()
  61. plt.close()
  62. out = np.array(PIL.Image.open(f))[:, :, :3]
  63. out = scipy.misc.imresize(out, img.shape[:2])
  64. return out
  65. def labelme_shapes_to_label(img_shape, shapes):
  66. label_name_to_val = {'background': 0}
  67. lbl = np.zeros(img_shape[:2], dtype=np.int32)
  68. for shape in sorted(shapes, key=lambda x: x['label']):
  69. polygons = shape['points']
  70. label_name = shape['label']
  71. if label_name in label_name_to_val:
  72. label_value = label_name_to_val[label_name]
  73. else:
  74. label_value = len(label_name_to_val)
  75. label_name_to_val[label_name] = label_value
  76. mask = polygons_to_mask(img_shape[:2], polygons)
  77. lbl[mask] = label_value
  78. lbl_names = [None] * (max(label_name_to_val.values()) + 1)
  79. for label_name, label_value in label_name_to_val.items():
  80. lbl_names[label_value] = label_name
  81. return lbl, lbl_names