test_utils.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
  1. import json
  2. import os.path as osp
  3. import numpy as np
  4. import PIL.Image
  5. import labelme
  6. here = osp.dirname(osp.abspath(__file__))
  7. data_dir = osp.join(here, 'data')
  8. def _get_img_and_data():
  9. json_file = osp.join(data_dir, 'apc2016_obj3.json')
  10. data = json.load(open(json_file))
  11. img_b64 = data['imageData']
  12. img = labelme.utils.img_b64_to_arr(img_b64)
  13. return img, data
  14. def _get_img_and_lbl():
  15. img, data = _get_img_and_data()
  16. label_name_to_value = {'__background__': 0}
  17. for shape in data['shapes']:
  18. label_name = shape['label']
  19. label_value = len(label_name_to_value)
  20. label_name_to_value[label_name] = label_value
  21. n_labels = max(label_name_to_value.values()) + 1
  22. label_names = [None] * n_labels
  23. for label_name, label_value in label_name_to_value.items():
  24. label_names[label_value] = label_name
  25. lbl = labelme.utils.shapes_to_label(
  26. img.shape, data['shapes'], label_name_to_value)
  27. return img, lbl, label_names
  28. # -----------------------------------------------------------------------------
  29. def test_img_b64_to_arr():
  30. img, _ = _get_img_and_data()
  31. assert img.dtype == np.uint8
  32. assert img.shape == (907, 1210, 3)
  33. def test_img_arr_to_b64():
  34. img_file = osp.join(data_dir, 'apc2016_obj3.jpg')
  35. img_arr = np.asarray(PIL.Image.open(img_file))
  36. img_b64 = labelme.utils.img_arr_to_b64(img_arr)
  37. img_arr2 = labelme.utils.img_b64_to_arr(img_b64)
  38. np.testing.assert_allclose(img_arr, img_arr2)
  39. def test_shapes_to_label():
  40. img, data = _get_img_and_data()
  41. label_name_to_value = {}
  42. for shape in data['shapes']:
  43. label_name = shape['label']
  44. label_value = len(label_name_to_value)
  45. label_name_to_value[label_name] = label_value
  46. cls = labelme.utils.shapes_to_label(
  47. img.shape, data['shapes'], label_name_to_value)
  48. assert cls.shape == img.shape[:2]
  49. def test_polygons_to_mask():
  50. img, data = _get_img_and_data()
  51. for shape in data['shapes']:
  52. polygons = shape['points']
  53. mask = labelme.utils.polygons_to_mask(img.shape[:2], polygons)
  54. assert mask.shape == img.shape[:2]
  55. def test_label_colormap():
  56. N = 255
  57. colormap = labelme.utils.label_colormap(N=N)
  58. assert colormap.shape == (N, 3)
  59. def test_label2rgb():
  60. img, lbl, label_names = _get_img_and_lbl()
  61. n_labels = len(label_names)
  62. viz = labelme.utils.label2rgb(lbl=lbl, n_labels=n_labels)
  63. assert lbl.shape == viz.shape[:2]
  64. assert viz.dtype == np.uint8
  65. viz = labelme.utils.label2rgb(lbl=lbl, img=img, n_labels=n_labels)
  66. assert img.shape[:2] == lbl.shape == viz.shape[:2]
  67. assert viz.dtype == np.uint8
  68. def test_draw_label():
  69. img, lbl, label_names = _get_img_and_lbl()
  70. viz = labelme.utils.draw_label(lbl, img, label_names=label_names)
  71. assert viz.shape[:2] == img.shape[:2] == lbl.shape[:2]
  72. assert viz.dtype == np.uint8
  73. def test_draw_instances():
  74. img, lbl, label_names = _get_img_and_lbl()
  75. labels_and_masks = {l: lbl == l for l in np.unique(lbl) if l != 0}
  76. labels, masks = zip(*labels_and_masks.items())
  77. masks = np.asarray(masks)
  78. bboxes = labelme.utils.masks_to_bboxes(masks)
  79. captions = [label_names[l] for l in labels]
  80. viz = labelme.utils.draw_instances(img, bboxes, labels, captions=captions)
  81. assert viz.shape[:2] == img.shape[:2]
  82. assert viz.dtype == np.uint8