shape.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. import math
  2. import numpy as np
  3. import PIL.Image
  4. import PIL.ImageDraw
  5. from labelme.logger import logger
  6. def polygons_to_mask(img_shape, polygons, shape_type=None):
  7. logger.warning(
  8. "The 'polygons_to_mask' function is deprecated, "
  9. "use 'shape_to_mask' instead."
  10. )
  11. return shape_to_mask(img_shape, points=polygons, shape_type=shape_type)
  12. def shape_to_mask(img_shape, points, shape_type=None,
  13. line_width=10, point_size=5):
  14. mask = np.zeros(img_shape[:2], dtype=np.uint8)
  15. mask = PIL.Image.fromarray(mask)
  16. draw = PIL.ImageDraw.Draw(mask)
  17. xy = [tuple(point) for point in points]
  18. if shape_type == 'circle':
  19. assert len(xy) == 2, 'Shape of shape_type=circle must have 2 points'
  20. (cx, cy), (px, py) = xy
  21. d = math.sqrt((cx - px) ** 2 + (cy - py) ** 2)
  22. draw.ellipse([cx - d, cy - d, cx + d, cy + d], outline=1, fill=1)
  23. elif shape_type == 'rectangle':
  24. assert len(xy) == 2, 'Shape of shape_type=rectangle must have 2 points'
  25. draw.rectangle(xy, outline=1, fill=1)
  26. elif shape_type == 'line':
  27. assert len(xy) == 2, 'Shape of shape_type=line must have 2 points'
  28. draw.line(xy=xy, fill=1, width=line_width)
  29. elif shape_type == 'linestrip':
  30. draw.line(xy=xy, fill=1, width=line_width)
  31. elif shape_type == 'point':
  32. assert len(xy) == 1, 'Shape of shape_type=point must have 1 points'
  33. cx, cy = xy[0]
  34. r = point_size
  35. draw.ellipse([cx - r, cy - r, cx + r, cy + r], outline=1, fill=1)
  36. else:
  37. assert len(xy) > 2, 'Polygon must have points more than 2'
  38. draw.polygon(xy=xy, outline=1, fill=1)
  39. mask = np.array(mask, dtype=bool)
  40. return mask
  41. def shapes_to_label(img_shape, shapes, label_name_to_value, type='class'):
  42. assert type in ['class', 'instance']
  43. cls = np.zeros(img_shape[:2], dtype=np.int32)
  44. if type == 'instance':
  45. ins = np.zeros(img_shape[:2], dtype=np.int32)
  46. instance_names = ['_background_']
  47. for shape in shapes:
  48. points = shape['points']
  49. label = shape['label']
  50. shape_type = shape.get('shape_type', None)
  51. if type == 'class':
  52. cls_name = label
  53. elif type == 'instance':
  54. cls_name = label.split('-')[0]
  55. if label not in instance_names:
  56. instance_names.append(label)
  57. ins_id = instance_names.index(label)
  58. cls_id = label_name_to_value[cls_name]
  59. mask = shape_to_mask(img_shape[:2], points, shape_type)
  60. cls[mask] = cls_id
  61. if type == 'instance':
  62. ins[mask] = ins_id
  63. if type == 'instance':
  64. return cls, ins
  65. return cls
  66. def labelme_shapes_to_label(img_shape, shapes):
  67. logger.warn('labelme_shapes_to_label is deprecated, so please use '
  68. 'shapes_to_label.')
  69. label_name_to_value = {'_background_': 0}
  70. for shape in shapes:
  71. label_name = shape['label']
  72. if label_name in label_name_to_value:
  73. label_value = label_name_to_value[label_name]
  74. else:
  75. label_value = len(label_name_to_value)
  76. label_name_to_value[label_name] = label_value
  77. lbl = shapes_to_label(img_shape, shapes, label_name_to_value)
  78. return lbl, label_name_to_value
  79. def masks_to_bboxes(masks):
  80. if masks.ndim != 3:
  81. raise ValueError(
  82. 'masks.ndim must be 3, but it is {}'
  83. .format(masks.ndim)
  84. )
  85. if masks.dtype != bool:
  86. raise ValueError(
  87. 'masks.dtype must be bool type, but it is {}'
  88. .format(masks.dtype)
  89. )
  90. bboxes = []
  91. for mask in masks:
  92. where = np.argwhere(mask)
  93. (y1, x1), (y2, x2) = where.min(0), where.max(0) + 1
  94. bboxes.append((y1, x1, y2, x2))
  95. bboxes = np.asarray(bboxes, dtype=np.float32)
  96. return bboxes