shape.py 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. from math import pow
  2. from math import sqrt
  3. import numpy as np
  4. import PIL.Image
  5. import PIL.ImageDraw
  6. from labelme import logger
  7. def polygons_to_mask(img_shape, polygons, shape_type=None):
  8. mask = np.zeros(img_shape[:2], dtype=np.uint8)
  9. mask = PIL.Image.fromarray(mask)
  10. draw = PIL.ImageDraw.Draw(mask)
  11. if shape_type == "circle" and len(polygons) == 2:
  12. ((cx, cy), (px, py)) = polygons
  13. d = sqrt(pow(cx - px, 2) + pow(cy - py, 2))
  14. draw.ellipse([cx - d, cy - d, cx + d, cy + d], outline=1, fill=1)
  15. else:
  16. xy = list(map(tuple, polygons))
  17. draw.polygon(xy=xy, outline=1, fill=1)
  18. mask = np.array(mask, dtype=bool)
  19. return mask
  20. def shapes_to_label(img_shape, shapes, label_name_to_value, type='class'):
  21. assert type in ['class', 'instance']
  22. cls = np.zeros(img_shape[:2], dtype=np.int32)
  23. if type == 'instance':
  24. ins = np.zeros(img_shape[:2], dtype=np.int32)
  25. instance_names = ['_background_']
  26. for shape in shapes:
  27. polygons = shape['points']
  28. label = shape['label']
  29. shape_type = shape.get('shape_type', None)
  30. if type == 'class':
  31. cls_name = label
  32. elif type == 'instance':
  33. cls_name = label.split('-')[0]
  34. if label not in instance_names:
  35. instance_names.append(label)
  36. ins_id = len(instance_names) - 1
  37. cls_id = label_name_to_value[cls_name]
  38. mask = polygons_to_mask(img_shape[:2], polygons, shape_type)
  39. cls[mask] = cls_id
  40. if type == 'instance':
  41. ins[mask] = ins_id
  42. if type == 'instance':
  43. return cls, ins
  44. return cls
  45. def labelme_shapes_to_label(img_shape, shapes):
  46. logger.warn('labelme_shapes_to_label is deprecated, so please use '
  47. 'shapes_to_label.')
  48. label_name_to_value = {'_background_': 0}
  49. for shape in shapes:
  50. label_name = shape['label']
  51. if label_name in label_name_to_value:
  52. label_value = label_name_to_value[label_name]
  53. else:
  54. label_value = len(label_name_to_value)
  55. label_name_to_value[label_name] = label_value
  56. lbl = shapes_to_label(img_shape, shapes, label_name_to_value)
  57. return lbl, label_name_to_value