test_draw.py 1.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748
  1. import numpy as np
  2. from labelme.utils import draw as draw_module
  3. from labelme.utils import shape as shape_module
  4. from .util import get_img_and_lbl
  5. # -----------------------------------------------------------------------------
  6. def test_label_colormap():
  7. N = 255
  8. colormap = draw_module.label_colormap(N=N)
  9. assert colormap.shape == (N, 3)
  10. def test_label2rgb():
  11. img, lbl, label_names = get_img_and_lbl()
  12. n_labels = len(label_names)
  13. viz = draw_module.label2rgb(lbl=lbl, n_labels=n_labels)
  14. assert lbl.shape == viz.shape[:2]
  15. assert viz.dtype == np.uint8
  16. viz = draw_module.label2rgb(lbl=lbl, img=img, n_labels=n_labels)
  17. assert img.shape[:2] == lbl.shape == viz.shape[:2]
  18. assert viz.dtype == np.uint8
  19. def test_draw_label():
  20. img, lbl, label_names = get_img_and_lbl()
  21. viz = draw_module.draw_label(lbl, img, label_names=label_names)
  22. assert viz.shape[:2] == img.shape[:2] == lbl.shape[:2]
  23. assert viz.dtype == np.uint8
  24. def test_draw_instances():
  25. img, lbl, label_names = get_img_and_lbl()
  26. labels_and_masks = {l: lbl == l for l in np.unique(lbl) if l != 0}
  27. labels, masks = zip(*labels_and_masks.items())
  28. masks = np.asarray(masks)
  29. bboxes = shape_module.masks_to_bboxes(masks)
  30. captions = [label_names[l] for l in labels]
  31. viz = draw_module.draw_instances(img, bboxes, labels, captions=captions)
  32. assert viz.shape[:2] == img.shape[:2]
  33. assert viz.dtype == np.uint8