shape.py 2.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. import math
  2. import numpy as np
  3. import PIL.Image
  4. import PIL.ImageDraw
  5. from labelme import logger
  6. def polygons_to_mask(img_shape, polygons, shape_type=None):
  7. logger.warning(
  8. "The 'polygons_to_mask' function is deprecated, "
  9. "use 'shape_to_mask' instead."
  10. )
  11. return shape_to_mask(img_shape, points=polygons, shape_type=shape_type)
  12. def shape_to_mask(img_shape, points, shape_type=None):
  13. mask = np.zeros(img_shape[:2], dtype=np.uint8)
  14. mask = PIL.Image.fromarray(mask)
  15. draw = PIL.ImageDraw.Draw(mask)
  16. if shape_type == 'circle' and len(points) == 2:
  17. (cx, cy), (px, py) = points
  18. d = math.sqrt((cx - px) ** 2 + (cy - py) ** 2)
  19. draw.ellipse([cx - d, cy - d, cx + d, cy + d], outline=1, fill=1)
  20. elif shape_type == 'rectangle' and len(points) == 2:
  21. xy = [tuple(point) for point in points]
  22. draw.rectangle(xy, outline=1, fill=1)
  23. else:
  24. xy = [tuple(point) for point in points]
  25. draw.polygon(xy=xy, outline=1, fill=1)
  26. mask = np.array(mask, dtype=bool)
  27. return mask
  28. def shapes_to_label(img_shape, shapes, label_name_to_value, type='class'):
  29. assert type in ['class', 'instance']
  30. cls = np.zeros(img_shape[:2], dtype=np.int32)
  31. if type == 'instance':
  32. ins = np.zeros(img_shape[:2], dtype=np.int32)
  33. instance_names = ['_background_']
  34. for shape in shapes:
  35. points = shape['points']
  36. label = shape['label']
  37. shape_type = shape.get('shape_type', None)
  38. if type == 'class':
  39. cls_name = label
  40. elif type == 'instance':
  41. cls_name = label.split('-')[0]
  42. if label not in instance_names:
  43. instance_names.append(label)
  44. ins_id = len(instance_names) - 1
  45. cls_id = label_name_to_value[cls_name]
  46. mask = shape_to_mask(img_shape[:2], points, shape_type)
  47. cls[mask] = cls_id
  48. if type == 'instance':
  49. ins[mask] = ins_id
  50. if type == 'instance':
  51. return cls, ins
  52. return cls
  53. def labelme_shapes_to_label(img_shape, shapes):
  54. logger.warn('labelme_shapes_to_label is deprecated, so please use '
  55. 'shapes_to_label.')
  56. label_name_to_value = {'_background_': 0}
  57. for shape in shapes:
  58. label_name = shape['label']
  59. if label_name in label_name_to_value:
  60. label_value = label_name_to_value[label_name]
  61. else:
  62. label_value = len(label_name_to_value)
  63. label_name_to_value[label_name] = label_value
  64. lbl = shapes_to_label(img_shape, shapes, label_name_to_value)
  65. return lbl, label_name_to_value