|
@@ -100,3 +100,15 @@ def test_draw_label():
|
|
viz = labelme.utils.draw_label(lbl, img, label_names=label_names)
|
|
viz = labelme.utils.draw_label(lbl, img, label_names=label_names)
|
|
assert viz.shape[:2] == img.shape[:2] == lbl.shape[:2]
|
|
assert viz.shape[:2] == img.shape[:2] == lbl.shape[:2]
|
|
assert viz.dtype == np.uint8
|
|
assert viz.dtype == np.uint8
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+def test_draw_instances():
|
|
|
|
+ img, lbl, label_names = _get_img_and_lbl()
|
|
|
|
+ labels_and_masks = {l: lbl == l for l in np.unique(lbl) if l != 0}
|
|
|
|
+ labels, masks = zip(*labels_and_masks.items())
|
|
|
|
+ masks = np.asarray(masks)
|
|
|
|
+ bboxes = labelme.utils.masks_to_bboxes(masks)
|
|
|
|
+ captions = [label_names[l] for l in labels]
|
|
|
|
+ viz = labelme.utils.draw_instances(img, bboxes, labels, captions=captions)
|
|
|
|
+ assert viz.shape[:2] == img.shape[:2]
|
|
|
|
+ assert viz.dtype == np.uint8
|