Răsfoiți Sursa

Add test for draw_instances

Kentaro Wada 6 ani în urmă
părinte
comite
7f57f09fb4
3 a modificat fișierele cu 33 adăugiri și 0 ștergeri
  1. 1 0
      labelme/utils/__init__.py
  2. 20 0
      labelme/utils/shape.py
  3. 12 0
      tests/test_utils.py

+ 1 - 0
labelme/utils/__init__.py

@@ -6,6 +6,7 @@ from .image import img_arr_to_b64
 from .image import img_b64_to_arr
 
 from .shape import labelme_shapes_to_label
+from .shape import masks_to_bboxes
 from .shape import polygons_to_mask
 from .shape import shape_to_mask
 from .shape import shapes_to_label

+ 20 - 0
labelme/utils/shape.py

@@ -77,3 +77,23 @@ def labelme_shapes_to_label(img_shape, shapes):
 
     lbl = shapes_to_label(img_shape, shapes, label_name_to_value)
     return lbl, label_name_to_value
+
+
+def masks_to_bboxes(masks):
+    if masks.ndim != 3:
+        raise ValueError(
+            'masks.ndim must be 3, but it is {}'
+            .format(masks.ndim)
+        )
+    if masks.dtype != bool:
+        raise ValueError(
+            'masks.dtype must be bool type, but it is {}'
+            .format(masks.dtype)
+        )
+    bboxes = []
+    for mask in masks:
+        where = np.argwhere(mask)
+        (y1, x1), (y2, x2) = where.min(0), where.max(0) + 1
+        bboxes.append((y1, x1, y2, x2))
+    bboxes = np.asarray(bboxes, dtype=np.float32)
+    return bboxes

+ 12 - 0
tests/test_utils.py

@@ -100,3 +100,15 @@ def test_draw_label():
     viz = labelme.utils.draw_label(lbl, img, label_names=label_names)
     assert viz.shape[:2] == img.shape[:2] == lbl.shape[:2]
     assert viz.dtype == np.uint8
+
+
+def test_draw_instances():
+    img, lbl, label_names = _get_img_and_lbl()
+    labels_and_masks = {l: lbl == l for l in np.unique(lbl) if l != 0}
+    labels, masks = zip(*labels_and_masks.items())
+    masks = np.asarray(masks)
+    bboxes = labelme.utils.masks_to_bboxes(masks)
+    captions = [label_names[l] for l in labels]
+    viz = labelme.utils.draw_instances(img, bboxes, labels, captions=captions)
+    assert viz.shape[:2] == img.shape[:2]
+    assert viz.dtype == np.uint8