浏览代码

Add tests for labelme/utils.py

Kentaro Wada 7 年之前
父节点
当前提交
dc64f8d8de
共有 1 个文件被更改,包括 75 次插入1 次删除
  1. 75 1
      tests/test_utils.py

+ 75 - 1
tests/test_utils.py

@@ -11,11 +11,38 @@ here = osp.dirname(osp.abspath(__file__))
 data_dir = osp.join(here, 'data')
 
 
-def test_img_b64_to_arr():
+def _get_img_and_data():
     json_file = osp.join(data_dir, 'apc2016_obj3.json')
     data = json.load(open(json_file))
     img_b64 = data['imageData']
     img = labelme.utils.img_b64_to_arr(img_b64)
+    return img, data
+
+
+def _get_img_and_lbl():
+    img, data = _get_img_and_data()
+
+    label_name_to_value = {'__background__': 0}
+    for shape in data['shapes']:
+        label_name = shape['label']
+        label_value = len(label_name_to_value)
+        label_name_to_value[label_name] = label_value
+
+    n_labels = max(label_name_to_value.values()) + 1
+    label_names = [None] * n_labels
+    for label_name, label_value in label_name_to_value.items():
+        label_names[label_value] = label_name
+
+    lbl = labelme.utils.shapes_to_label(
+        img.shape, data['shapes'], label_name_to_value)
+    return img, lbl, label_names
+
+
+# -----------------------------------------------------------------------------
+
+
+def test_img_b64_to_arr():
+    img, _ = _get_img_and_data()
     assert img.dtype == np.uint8
     assert img.shape == (907, 1210, 3)
 
@@ -26,3 +53,50 @@ def test_img_arr_to_b64():
     img_b64 = labelme.utils.img_arr_to_b64(img_arr)
     img_arr2 = labelme.utils.img_b64_to_arr(img_b64)
     np.testing.assert_allclose(img_arr, img_arr2)
+
+
+def test_shapes_to_label():
+    img, data = _get_img_and_data()
+    label_name_to_value = {}
+    for shape in data['shapes']:
+        label_name = shape['label']
+        label_value = len(label_name_to_value)
+        label_name_to_value[label_name] = label_value
+    cls = labelme.utils.shapes_to_label(
+        img.shape, data['shapes'], label_name_to_value)
+    assert cls.shape == img.shape[:2]
+
+
+def test_polygons_to_mask():
+    img, data = _get_img_and_data()
+    for shape in data['shapes']:
+        polygons = shape['points']
+        mask = labelme.utils.polygons_to_mask(img.shape[:2], polygons)
+        assert mask.shape == img.shape[:2]
+
+
+def test_label_colormap():
+    N = 255
+    colormap = labelme.utils.label_colormap(N=N)
+    assert colormap.shape == (N, 3)
+
+
+def test_label2rgb():
+    img, lbl, label_names = _get_img_and_lbl()
+    n_labels = len(label_names)
+
+    viz = labelme.utils.label2rgb(lbl=lbl, n_labels=n_labels)
+    assert lbl.shape == viz.shape[:2]
+    assert viz.dtype == np.uint8
+
+    viz = labelme.utils.label2rgb(lbl=lbl, img=img, n_labels=n_labels)
+    assert img.shape[:2] == lbl.shape == viz.shape[:2]
+    assert viz.dtype == np.uint8
+
+
+def test_draw_label():
+    img, lbl, label_names = _get_img_and_lbl()
+
+    viz = labelme.utils.draw_label(lbl, img, label_names=label_names)
+    assert viz.shape[:2] == img.shape[:2] == lbl.shape[:2]
+    assert viz.dtype == np.uint8