utils.py 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
import base64
import cStringIO as StringIO
import io

import matplotlib.pyplot as plt
import numpy as np
import PIL.Image
import PIL.ImageDraw
import scipy.misc
import skimage.color
  9. def labelcolormap(N=256):
  10. def bitget(byteval, idx):
  11. return ((byteval & (1 << idx)) != 0)
  12. cmap = np.zeros((N, 3))
  13. for i in xrange(0, N):
  14. id = i
  15. r, g, b = 0, 0, 0
  16. for j in xrange(0, 8):
  17. r = np.bitwise_or(r, (bitget(id, 0) << 7-j))
  18. g = np.bitwise_or(g, (bitget(id, 1) << 7-j))
  19. b = np.bitwise_or(b, (bitget(id, 2) << 7-j))
  20. id = (id >> 3)
  21. cmap[i, 0] = r
  22. cmap[i, 1] = g
  23. cmap[i, 2] = b
  24. cmap = cmap.astype(np.float32) / 255
  25. return cmap
  26. def img_b64_to_array(img_b64):
  27. f = StringIO.StringIO()
  28. f.write(base64.b64decode(img_b64))
  29. img_arr = np.array(PIL.Image.open(f))
  30. return img_arr
  31. def polygons_to_mask(img_shape, polygons):
  32. mask = np.zeros(img_shape[:2], dtype=np.uint8)
  33. mask = PIL.Image.fromarray(mask)
  34. xy = map(tuple, polygons)
  35. PIL.ImageDraw.Draw(mask).polygon(xy=xy, outline=1, fill=1)
  36. mask = np.array(mask, dtype=bool)
  37. return mask
  38. def draw_label(label, img, label_names):
  39. plt.subplots_adjust(left=0, right=1, top=1, bottom=0,
  40. wspace=0, hspace=0)
  41. plt.margins(0, 0)
  42. plt.gca().xaxis.set_major_locator(plt.NullLocator())
  43. plt.gca().yaxis.set_major_locator(plt.NullLocator())
  44. cmap = labelcolormap(len(label_names))
  45. label_viz = skimage.color.label2rgb(
  46. label, img, colors=cmap[1:], bg_label=0)
  47. plt.imshow(label_viz)
  48. plt.axis('off')
  49. plt_handlers = []
  50. plt_titles = []
  51. for label_value, label_name in enumerate(label_names):
  52. fc = cmap[label_value]
  53. p = plt.Rectangle((0, 0), 1, 1, fc=fc)
  54. plt_handlers.append(p)
  55. plt_titles.append(label_name)
  56. plt.legend(plt_handlers, plt_titles, loc='lower right', framealpha=.5)
  57. f = StringIO.StringIO()
  58. plt.savefig(f, bbox_inches='tight', pad_inches=0)
  59. plt.cla()
  60. plt.close()
  61. out = np.array(PIL.Image.open(f))[:, :, :3]
  62. out = scipy.misc.imresize(out, img.shape[:2])
  63. return out
  64. def labelme_shapes_to_label(img_shape, shapes):
  65. label_name_to_val = {'background': 0}
  66. lbl = np.zeros(img_shape[:2], dtype=np.int32)
  67. for shape in sorted(shapes, key=lambda x: x['label']):
  68. polygons = shape['points']
  69. label_name = shape['label']
  70. if label_name in label_name_to_val:
  71. label_value = label_name_to_val[label_name]
  72. else:
  73. label_value = len(label_name_to_val)
  74. label_name_to_val[label_name] = label_value
  75. mask = polygons_to_mask(img_shape[:2], polygons)
  76. lbl[mask] = label_value
  77. lbl_names = [None] * (max(label_name_to_val.values()) + 1)
  78. for label_name, label_value in label_name_to_val.items():
  79. lbl_names[label_value] = label_name
  80. return lbl, lbl_names