shape.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383
  1. import copy
  2. import math
  3. import numpy as np
  4. import skimage.measure
  5. from qtpy import QtCore
  6. from qtpy import QtGui
  7. import labelme.utils
  8. from labelme.logger import logger
  9. # TODO(unknown):
  10. # - [opt] Store paths instead of creating new ones at each paint.
  11. class Shape(object):
  12. # Render handles as squares
  13. P_SQUARE = 0
  14. # Render handles as circles
  15. P_ROUND = 1
  16. # Flag for the handles we would move if dragging
  17. MOVE_VERTEX = 0
  18. # Flag for all other handles on the current shape
  19. NEAR_VERTEX = 1
  20. # The following class variables influence the drawing of all shape objects.
  21. line_color = None
  22. fill_color = None
  23. select_line_color = None
  24. select_fill_color = None
  25. vertex_fill_color = None
  26. hvertex_fill_color = None
  27. point_type = P_ROUND
  28. point_size = 8
  29. scale = 1.0
  30. def __init__(
  31. self,
  32. label=None,
  33. line_color=None,
  34. shape_type=None,
  35. flags=None,
  36. group_id=None,
  37. description=None,
  38. mask=None,
  39. ):
  40. self.label = label
  41. self.group_id = group_id
  42. self.points = []
  43. self.point_labels = []
  44. self.shape_type = shape_type
  45. self._shape_raw = None
  46. self._points_raw = []
  47. self._shape_type_raw = None
  48. self.fill = False
  49. self.selected = False
  50. self.flags = flags
  51. self.description = description
  52. self.other_data = {}
  53. self.mask = mask
  54. self._highlightIndex = None
  55. self._highlightMode = self.NEAR_VERTEX
  56. self._highlightSettings = {
  57. self.NEAR_VERTEX: (4, self.P_ROUND),
  58. self.MOVE_VERTEX: (1.5, self.P_SQUARE),
  59. }
  60. self._closed = False
  61. if line_color is not None:
  62. # Override the class line_color attribute
  63. # with an object attribute. Currently this
  64. # is used for drawing the pending line a different color.
  65. self.line_color = line_color
  66. def setShapeRefined(self, shape_type, points, point_labels, mask=None):
  67. self._shape_raw = (self.shape_type, self.points, self.point_labels)
  68. self.shape_type = shape_type
  69. self.points = points
  70. self.point_labels = point_labels
  71. self.mask = mask
  72. def restoreShapeRaw(self):
  73. if self._shape_raw is None:
  74. return
  75. self.shape_type, self.points, self.point_labels = self._shape_raw
  76. self._shape_raw = None
  77. @property
  78. def shape_type(self):
  79. return self._shape_type
  80. @shape_type.setter
  81. def shape_type(self, value):
  82. if value is None:
  83. value = "polygon"
  84. if value not in [
  85. "polygon",
  86. "rectangle",
  87. "point",
  88. "line",
  89. "circle",
  90. "linestrip",
  91. "points",
  92. "mask",
  93. ]:
  94. raise ValueError("Unexpected shape_type: {}".format(value))
  95. self._shape_type = value
  96. def close(self):
  97. self._closed = True
  98. def addPoint(self, point, label=1):
  99. if self.points and point == self.points[0]:
  100. self.close()
  101. else:
  102. self.points.append(point)
  103. self.point_labels.append(label)
  104. def canAddPoint(self):
  105. return self.shape_type in ["polygon", "linestrip"]
  106. def popPoint(self):
  107. if self.points:
  108. if self.point_labels:
  109. self.point_labels.pop()
  110. return self.points.pop()
  111. return None
  112. def insertPoint(self, i, point, label=1):
  113. self.points.insert(i, point)
  114. self.point_labels.insert(i, label)
  115. def removePoint(self, i):
  116. if not self.canAddPoint():
  117. logger.warning(
  118. "Cannot remove point from: shape_type=%r",
  119. self.shape_type,
  120. )
  121. return
  122. if self.shape_type == "polygon" and len(self.points) <= 3:
  123. logger.warning(
  124. "Cannot remove point from: shape_type=%r, len(points)=%d",
  125. self.shape_type,
  126. len(self.points),
  127. )
  128. return
  129. if self.shape_type == "linestrip" and len(self.points) <= 2:
  130. logger.warning(
  131. "Cannot remove point from: shape_type=%r, len(points)=%d",
  132. self.shape_type,
  133. len(self.points),
  134. )
  135. return
  136. self.points.pop(i)
  137. self.point_labels.pop(i)
  138. def isClosed(self):
  139. return self._closed
  140. def setOpen(self):
  141. self._closed = False
  142. def getRectFromLine(self, pt1, pt2):
  143. x1, y1 = pt1.x(), pt1.y()
  144. x2, y2 = pt2.x(), pt2.y()
  145. return QtCore.QRectF(x1, y1, x2 - x1, y2 - y1)
  146. def paint(self, painter):
  147. if self.mask is None and not self.points:
  148. return
  149. color = self.select_line_color if self.selected else self.line_color
  150. pen = QtGui.QPen(color)
  151. # Try using integer sizes for smoother drawing(?)
  152. pen.setWidth(max(1, int(round(2.0 / self.scale))))
  153. painter.setPen(pen)
  154. if self.mask is not None:
  155. image_to_draw = np.zeros(self.mask.shape + (4,), dtype=np.uint8)
  156. fill_color = (
  157. self.select_fill_color.getRgb()
  158. if self.selected
  159. else self.fill_color.getRgb()
  160. )
  161. image_to_draw[self.mask] = fill_color
  162. qimage = QtGui.QImage.fromData(labelme.utils.img_arr_to_data(image_to_draw))
  163. painter.drawImage(
  164. int(round(self.points[0].x())),
  165. int(round(self.points[0].y())),
  166. qimage,
  167. )
  168. line_path = QtGui.QPainterPath()
  169. contours = skimage.measure.find_contours(np.pad(self.mask, pad_width=1))
  170. for contour in contours:
  171. contour += [self.points[0].y(), self.points[0].x()]
  172. line_path.moveTo(contour[0, 1], contour[0, 0])
  173. for point in contour[1:]:
  174. line_path.lineTo(point[1], point[0])
  175. painter.drawPath(line_path)
  176. if self.points:
  177. line_path = QtGui.QPainterPath()
  178. vrtx_path = QtGui.QPainterPath()
  179. negative_vrtx_path = QtGui.QPainterPath()
  180. if self.shape_type in ["rectangle", "mask"]:
  181. assert len(self.points) in [1, 2]
  182. if len(self.points) == 2:
  183. rectangle = self.getRectFromLine(*self.points)
  184. line_path.addRect(rectangle)
  185. if self.shape_type == "rectangle":
  186. for i in range(len(self.points)):
  187. self.drawVertex(vrtx_path, i)
  188. elif self.shape_type == "circle":
  189. assert len(self.points) in [1, 2]
  190. if len(self.points) == 2:
  191. rectangle = self.getCircleRectFromLine(self.points)
  192. line_path.addEllipse(rectangle)
  193. for i in range(len(self.points)):
  194. self.drawVertex(vrtx_path, i)
  195. elif self.shape_type == "linestrip":
  196. line_path.moveTo(self.points[0])
  197. for i, p in enumerate(self.points):
  198. line_path.lineTo(p)
  199. self.drawVertex(vrtx_path, i)
  200. elif self.shape_type == "points":
  201. assert len(self.points) == len(self.point_labels)
  202. for i, point_label in enumerate(self.point_labels):
  203. if point_label == 1:
  204. self.drawVertex(vrtx_path, i)
  205. else:
  206. self.drawVertex(negative_vrtx_path, i)
  207. else:
  208. line_path.moveTo(self.points[0])
  209. # Uncommenting the following line will draw 2 paths
  210. # for the 1st vertex, and make it non-filled, which
  211. # may be desirable.
  212. # self.drawVertex(vrtx_path, 0)
  213. for i, p in enumerate(self.points):
  214. line_path.lineTo(p)
  215. self.drawVertex(vrtx_path, i)
  216. if self.isClosed():
  217. line_path.lineTo(self.points[0])
  218. painter.drawPath(line_path)
  219. if vrtx_path.length() > 0:
  220. painter.drawPath(vrtx_path)
  221. painter.fillPath(vrtx_path, self._vertex_fill_color)
  222. if self.fill and self.mask is None:
  223. color = self.select_fill_color if self.selected else self.fill_color
  224. painter.fillPath(line_path, color)
  225. pen.setColor(QtGui.QColor(255, 0, 0, 255))
  226. painter.setPen(pen)
  227. painter.drawPath(negative_vrtx_path)
  228. painter.fillPath(negative_vrtx_path, QtGui.QColor(255, 0, 0, 255))
  229. def drawVertex(self, path, i):
  230. d = self.point_size / self.scale
  231. shape = self.point_type
  232. point = self.points[i]
  233. if i == self._highlightIndex:
  234. size, shape = self._highlightSettings[self._highlightMode]
  235. d *= size
  236. if self._highlightIndex is not None:
  237. self._vertex_fill_color = self.hvertex_fill_color
  238. else:
  239. self._vertex_fill_color = self.vertex_fill_color
  240. if shape == self.P_SQUARE:
  241. path.addRect(point.x() - d / 2, point.y() - d / 2, d, d)
  242. elif shape == self.P_ROUND:
  243. path.addEllipse(point, d / 2.0, d / 2.0)
  244. else:
  245. assert False, "unsupported vertex shape"
  246. def nearestVertex(self, point, epsilon):
  247. min_distance = float("inf")
  248. min_i = None
  249. for i, p in enumerate(self.points):
  250. dist = labelme.utils.distance(p - point)
  251. if dist <= epsilon and dist < min_distance:
  252. min_distance = dist
  253. min_i = i
  254. return min_i
  255. def nearestEdge(self, point, epsilon):
  256. min_distance = float("inf")
  257. post_i = None
  258. for i in range(len(self.points)):
  259. line = [self.points[i - 1], self.points[i]]
  260. dist = labelme.utils.distancetoline(point, line)
  261. if dist <= epsilon and dist < min_distance:
  262. min_distance = dist
  263. post_i = i
  264. return post_i
  265. def containsPoint(self, point):
  266. if self.mask is not None:
  267. y = np.clip(
  268. int(round(point.y() - self.points[0].y())),
  269. 0,
  270. self.mask.shape[0] - 1,
  271. )
  272. x = np.clip(
  273. int(round(point.x() - self.points[0].x())),
  274. 0,
  275. self.mask.shape[1] - 1,
  276. )
  277. return self.mask[y, x]
  278. return self.makePath().contains(point)
  279. def getCircleRectFromLine(self, line):
  280. """Computes parameters to draw with `QPainterPath::addEllipse`"""
  281. if len(line) != 2:
  282. return None
  283. (c, point) = line
  284. r = line[0] - line[1]
  285. d = math.sqrt(math.pow(r.x(), 2) + math.pow(r.y(), 2))
  286. rectangle = QtCore.QRectF(c.x() - d, c.y() - d, 2 * d, 2 * d)
  287. return rectangle
  288. def makePath(self):
  289. if self.shape_type in ["rectangle", "mask"]:
  290. path = QtGui.QPainterPath()
  291. if len(self.points) == 2:
  292. rectangle = self.getRectFromLine(*self.points)
  293. path.addRect(rectangle)
  294. elif self.shape_type == "circle":
  295. path = QtGui.QPainterPath()
  296. if len(self.points) == 2:
  297. rectangle = self.getCircleRectFromLine(self.points)
  298. path.addEllipse(rectangle)
  299. else:
  300. path = QtGui.QPainterPath(self.points[0])
  301. for p in self.points[1:]:
  302. path.lineTo(p)
  303. return path
  304. def boundingRect(self):
  305. return self.makePath().boundingRect()
  306. def moveBy(self, offset):
  307. self.points = [p + offset for p in self.points]
  308. def moveVertexBy(self, i, offset):
  309. self.points[i] = self.points[i] + offset
  310. def highlightVertex(self, i, action):
  311. """Highlight a vertex appropriately based on the current action
  312. Args:
  313. i (int): The vertex index
  314. action (int): The action
  315. (see Shape.NEAR_VERTEX and Shape.MOVE_VERTEX)
  316. """
  317. self._highlightIndex = i
  318. self._highlightMode = action
  319. def highlightClear(self):
  320. """Clear the highlighted point"""
  321. self._highlightIndex = None
  322. def copy(self):
  323. return copy.deepcopy(self)
  324. def __len__(self):
  325. return len(self.points)
  326. def __getitem__(self, key):
  327. return self.points[key]
  328. def __setitem__(self, key, value):
  329. self.points[key] = value