shape.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. # MIT License
  2. # Copyright (c) Kentaro Wada
  3. import math
  4. import uuid
  5. import numpy as np
  6. import PIL.Image
  7. import PIL.ImageDraw
  8. from labelme.logger import logger
  9. def polygons_to_mask(img_shape, polygons, shape_type=None):
  10. logger.warning(
  11. "The 'polygons_to_mask' function is deprecated, " "use 'shape_to_mask' instead."
  12. )
  13. return shape_to_mask(img_shape, points=polygons, shape_type=shape_type)
  14. def shape_to_mask(img_shape, points, shape_type=None, line_width=10, point_size=5):
  15. mask = np.zeros(img_shape[:2], dtype=np.uint8)
  16. mask = PIL.Image.fromarray(mask)
  17. draw = PIL.ImageDraw.Draw(mask)
  18. xy = [tuple(point) for point in points]
  19. if shape_type == "circle":
  20. assert len(xy) == 2, "Shape of shape_type=circle must have 2 points"
  21. (cx, cy), (px, py) = xy
  22. d = math.sqrt((cx - px) ** 2 + (cy - py) ** 2)
  23. draw.ellipse([cx - d, cy - d, cx + d, cy + d], outline=1, fill=1)
  24. elif shape_type == "rectangle":
  25. assert len(xy) == 2, "Shape of shape_type=rectangle must have 2 points"
  26. draw.rectangle(xy, outline=1, fill=1)
  27. elif shape_type == "line":
  28. assert len(xy) == 2, "Shape of shape_type=line must have 2 points"
  29. draw.line(xy=xy, fill=1, width=line_width)
  30. elif shape_type == "linestrip":
  31. draw.line(xy=xy, fill=1, width=line_width)
  32. elif shape_type == "point":
  33. assert len(xy) == 1, "Shape of shape_type=point must have 1 points"
  34. cx, cy = xy[0]
  35. r = point_size
  36. draw.ellipse([cx - r, cy - r, cx + r, cy + r], outline=1, fill=1)
  37. else:
  38. assert len(xy) > 2, "Polygon must have points more than 2"
  39. draw.polygon(xy=xy, outline=1, fill=1)
  40. mask = np.array(mask, dtype=bool)
  41. return mask
  42. def shapes_to_label(img_shape, shapes, label_name_to_value):
  43. cls = np.zeros(img_shape[:2], dtype=np.int32)
  44. ins = np.zeros_like(cls)
  45. instances = []
  46. for shape in shapes:
  47. points = shape["points"]
  48. label = shape["label"]
  49. group_id = shape.get("group_id")
  50. if group_id is None:
  51. group_id = uuid.uuid1()
  52. shape_type = shape.get("shape_type", None)
  53. cls_name = label
  54. instance = (cls_name, group_id)
  55. if instance not in instances:
  56. instances.append(instance)
  57. ins_id = instances.index(instance) + 1
  58. cls_id = label_name_to_value[cls_name]
  59. mask = shape_to_mask(img_shape[:2], points, shape_type)
  60. cls[mask] = cls_id
  61. ins[mask] = ins_id
  62. return cls, ins
  63. def labelme_shapes_to_label(img_shape, shapes):
  64. logger.warn(
  65. "labelme_shapes_to_label is deprecated, so please use " "shapes_to_label."
  66. )
  67. label_name_to_value = {"_background_": 0}
  68. for shape in shapes:
  69. label_name = shape["label"]
  70. if label_name in label_name_to_value:
  71. label_value = label_name_to_value[label_name]
  72. else:
  73. label_value = len(label_name_to_value)
  74. label_name_to_value[label_name] = label_value
  75. lbl, _ = shapes_to_label(img_shape, shapes, label_name_to_value)
  76. return lbl, label_name_to_value
  77. def masks_to_bboxes(masks):
  78. if masks.ndim != 3:
  79. raise ValueError("masks.ndim must be 3, but it is {}".format(masks.ndim))
  80. if masks.dtype != bool:
  81. raise ValueError(
  82. "masks.dtype must be bool type, but it is {}".format(masks.dtype)
  83. )
  84. bboxes = []
  85. for mask in masks:
  86. where = np.argwhere(mask)
  87. (y1, x1), (y2, x2) = where.min(0), where.max(0) + 1
  88. bboxes.append((y1, x1, y2, x2))
  89. bboxes = np.asarray(bboxes, dtype=np.float32)
  90. return bboxes