utils.py 2.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. import base64
  2. import cStringIO as StringIO
  3. import matplotlib.pyplot as plt
  4. import numpy as np
  5. import PIL.Image
  6. import PIL.ImageDraw
  7. import skimage.color
  8. def labelcolormap(N=256):
  9. def bitget(byteval, idx):
  10. return ((byteval & (1 << idx)) != 0)
  11. cmap = np.zeros((N, 3))
  12. for i in xrange(0, N):
  13. id = i
  14. r, g, b = 0, 0, 0
  15. for j in xrange(0, 8):
  16. r = np.bitwise_or(r, (bitget(id, 0) << 7-j))
  17. g = np.bitwise_or(g, (bitget(id, 1) << 7-j))
  18. b = np.bitwise_or(b, (bitget(id, 2) << 7-j))
  19. id = (id >> 3)
  20. cmap[i, 0] = r
  21. cmap[i, 1] = g
  22. cmap[i, 2] = b
  23. cmap = cmap.astype(np.float32) / 255
  24. return cmap
  25. def img_b64_to_array(img_b64):
  26. f = StringIO.StringIO()
  27. f.write(base64.b64decode(img_b64))
  28. img_arr = np.array(PIL.Image.open(f))
  29. return img_arr
  30. def polygons_to_mask(img_shape, polygons):
  31. mask = np.zeros(img_shape[:2], dtype=np.uint8)
  32. mask = PIL.Image.fromarray(mask)
  33. xy = map(tuple, polygons)
  34. PIL.ImageDraw.Draw(mask).polygon(xy=xy, outline=1, fill=1)
  35. mask = np.array(mask, dtype=bool)
  36. return mask
  37. def draw_label(label, img, label_names):
  38. plt.subplots_adjust(left=0, right=1, top=1, bottom=0,
  39. wspace=0, hspace=0)
  40. plt.margins(0, 0)
  41. plt.gca().xaxis.set_major_locator(plt.NullLocator())
  42. plt.gca().yaxis.set_major_locator(plt.NullLocator())
  43. cmap = labelcolormap(len(label_names))
  44. label_viz = skimage.color.label2rgb(label, img, bg_label=0)
  45. plt.imshow(label_viz)
  46. plt_handlers = []
  47. plt_titles = []
  48. for label_value, label_name in enumerate(label_names):
  49. fc = cmap[label_value]
  50. p = plt.Rectangle((0, 0), 1, 1, fc=fc)
  51. plt_handlers.append(p)
  52. plt_titles.append(label_name)
  53. plt.legend(plt_handlers, plt_titles, loc='lower right', framealpha=.5)
  54. f = StringIO.StringIO()
  55. plt.savefig(f, bbox_inches='tight', pad_inches=0)
  56. plt.cla()
  57. img = np.array(PIL.Image.open(f))[:, :, :3]
  58. return img
  59. def labelme_shapes_to_label(img_shape, shapes):
  60. label_name_to_val = {'background': 0}
  61. lbl = np.zeros(img_shape[:2], dtype=np.int32)
  62. for shape in sorted(shapes, key=lambda x: x['label']):
  63. polygons = shape['points']
  64. label_name = shape['label']
  65. if label_name in label_name_to_val:
  66. label_value = label_name_to_val[label_name]
  67. else:
  68. label_value = len(label_name_to_val)
  69. label_name_to_val[label_name] = label_value
  70. mask = polygons_to_mask(img_shape[:2], polygons)
  71. lbl[mask] = label_value
  72. lbl_names = [None] * (max(label_name_to_val.values()) + 1)
  73. for label_name, label_value in label_name_to_val.items():
  74. lbl_names[label_value] = label_name
  75. return lbl, lbl_names