label_dialog.py 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261
  1. import re
  2. from qtpy import QT_VERSION
  3. from qtpy import QtCore
  4. from qtpy import QtGui
  5. from qtpy import QtWidgets
  6. import labelme.utils
  7. from labelme.logger import logger
  8. QT5 = QT_VERSION[0] == "5"
  9. # TODO(unknown):
  10. # - Calculate optimal position so as not to go out of screen area.
  11. class LabelQLineEdit(QtWidgets.QLineEdit):
  12. def setListWidget(self, list_widget):
  13. self.list_widget = list_widget
  14. def keyPressEvent(self, e):
  15. if e.key() in [QtCore.Qt.Key_Up, QtCore.Qt.Key_Down]:
  16. self.list_widget.keyPressEvent(e)
  17. else:
  18. super(LabelQLineEdit, self).keyPressEvent(e)
  19. class LabelDialog(QtWidgets.QDialog):
  20. def __init__(
  21. self,
  22. text="Enter object label",
  23. parent=None,
  24. labels=None,
  25. sort_labels=True,
  26. show_text_field=True,
  27. completion="startswith",
  28. fit_to_content=None,
  29. flags=None,
  30. ):
  31. if fit_to_content is None:
  32. fit_to_content = {"row": False, "column": True}
  33. self._fit_to_content = fit_to_content
  34. super(LabelDialog, self).__init__(parent)
  35. self.edit = LabelQLineEdit()
  36. self.edit.setPlaceholderText(text)
  37. self.edit.setValidator(labelme.utils.labelValidator())
  38. self.edit.editingFinished.connect(self.postProcess)
  39. self.dragging=False
  40. if flags:
  41. self.edit.textChanged.connect(self.updateFlags)
  42. self.edit_group_id = QtWidgets.QLineEdit()
  43. self.edit_group_id.setPlaceholderText("Group ID")
  44. self.edit_group_id.setValidator(
  45. QtGui.QRegExpValidator(QtCore.QRegExp(r"\d*"), None)
  46. )
  47. layout = QtWidgets.QVBoxLayout()
  48. if show_text_field:
  49. layout_edit = QtWidgets.QHBoxLayout()
  50. layout_edit.addWidget(self.edit, 6)
  51. layout_edit.addWidget(self.edit_group_id, 2)
  52. layout.addLayout(layout_edit)
  53. # buttons
  54. self.buttonBox = bb = QtWidgets.QDialogButtonBox(
  55. QtWidgets.QDialogButtonBox.Ok | QtWidgets.QDialogButtonBox.Cancel,
  56. QtCore.Qt.Horizontal,
  57. self,
  58. )
  59. bb.button(bb.Ok).setIcon(labelme.utils.newIcon("done"))
  60. bb.button(bb.Cancel).setIcon(labelme.utils.newIcon("undo"))
  61. bb.accepted.connect(self.validate)
  62. bb.rejected.connect(self.reject)
  63. layout.addWidget(bb)
  64. # label_list
  65. self.labelList = QtWidgets.QListWidget()
  66. if self._fit_to_content["row"]:
  67. self.labelList.setHorizontalScrollBarPolicy(QtCore.Qt.ScrollBarAlwaysOff)
  68. if self._fit_to_content["column"]:
  69. self.labelList.setVerticalScrollBarPolicy(QtCore.Qt.ScrollBarAlwaysOff)
  70. self._sort_labels = sort_labels
  71. if labels:
  72. self.labelList.addItems(labels)
  73. if self._sort_labels:
  74. self.labelList.sortItems()
  75. else:
  76. self.labelList.setDragDropMode(QtWidgets.QAbstractItemView.InternalMove)
  77. self.labelList.currentItemChanged.connect(self.labelSelected)
  78. self.labelList.itemDoubleClicked.connect(self.labelDoubleClicked)
  79. self.labelList.setFixedHeight(150)
  80. self.edit.setListWidget(self.labelList)
  81. layout.addWidget(self.labelList)
  82. # label_flags
  83. if flags is None:
  84. flags = {}
  85. self._flags = flags
  86. self.flagsLayout = QtWidgets.QVBoxLayout()
  87. self.resetFlags()
  88. layout.addItem(self.flagsLayout)
  89. self.edit.textChanged.connect(self.updateFlags)
  90. # text edit
  91. self.editDescription = QtWidgets.QTextEdit()
  92. self.editDescription.setPlaceholderText("Label description")
  93. self.editDescription.setFixedHeight(50)
  94. layout.addWidget(self.editDescription)
  95. self.setLayout(layout)
  96. # completion
  97. completer = QtWidgets.QCompleter()
  98. if not QT5 and completion != "startswith":
  99. logger.warn(
  100. "completion other than 'startswith' is only "
  101. "supported with Qt5. Using 'startswith'"
  102. )
  103. completion = "startswith"
  104. if completion == "startswith":
  105. completer.setCompletionMode(QtWidgets.QCompleter.InlineCompletion)
  106. # Default settings.
  107. # completer.setFilterMode(QtCore.Qt.MatchStartsWith)
  108. elif completion == "contains":
  109. completer.setCompletionMode(QtWidgets.QCompleter.PopupCompletion)
  110. completer.setFilterMode(QtCore.Qt.MatchContains)
  111. else:
  112. raise ValueError("Unsupported completion: {}".format(completion))
  113. completer.setModel(self.labelList.model())
  114. self.edit.setCompleter(completer)
  115. def addLabelHistory(self, label):
  116. if self.labelList.findItems(label, QtCore.Qt.MatchExactly):
  117. return
  118. self.labelList.addItem(label)
  119. if self._sort_labels:
  120. self.labelList.sortItems()
  121. def labelSelected(self, item):
  122. self.edit.setText(item.text())
  123. def validate(self):
  124. text = self.edit.text()
  125. if hasattr(text, "strip"):
  126. text = text.strip()
  127. else:
  128. text = text.trimmed()
  129. if text:
  130. self.accept()
  131. def labelDoubleClicked(self, item):
  132. self.validate()
  133. def postProcess(self):
  134. text = self.edit.text()
  135. if hasattr(text, "strip"):
  136. text = text.strip()
  137. else:
  138. text = text.trimmed()
  139. self.edit.setText(text)
  140. def updateFlags(self, label_new):
  141. # keep state of shared flags
  142. flags_old = self.getFlags()
  143. flags_new = {}
  144. for pattern, keys in self._flags.items():
  145. if re.match(pattern, label_new):
  146. for key in keys:
  147. flags_new[key] = flags_old.get(key, False)
  148. self.setFlags(flags_new)
  149. def deleteFlags(self):
  150. for i in reversed(range(self.flagsLayout.count())):
  151. item = self.flagsLayout.itemAt(i).widget()
  152. self.flagsLayout.removeWidget(item)
  153. item.setParent(None)
  154. def resetFlags(self, label=""):
  155. flags = {}
  156. for pattern, keys in self._flags.items():
  157. if re.match(pattern, label):
  158. for key in keys:
  159. flags[key] = False
  160. self.setFlags(flags)
  161. def setFlags(self, flags):
  162. self.deleteFlags()
  163. for key in flags:
  164. item = QtWidgets.QCheckBox(key, self)
  165. item.setChecked(flags[key])
  166. self.flagsLayout.addWidget(item)
  167. item.show()
  168. def getFlags(self):
  169. flags = {}
  170. for i in range(self.flagsLayout.count()):
  171. item = self.flagsLayout.itemAt(i).widget()
  172. flags[item.text()] = item.isChecked()
  173. return flags
  174. def getGroupId(self):
  175. group_id = self.edit_group_id.text()
  176. if group_id:
  177. return int(group_id)
  178. return None
  179. def popUp(self, text=None, move=True, flags=None, group_id=None, description=None):
  180. self.show()
  181. if self._fit_to_content["row"]:
  182. self.labelList.setMinimumHeight(
  183. self.labelList.sizeHintForRow(0) * self.labelList.count() + 2
  184. )
  185. if self._fit_to_content["column"]:
  186. self.labelList.setMinimumWidth(self.labelList.sizeHintForColumn(0) + 2)
  187. # if text is None, the previous label in self.edit is kept
  188. if text is None:
  189. text = self.edit.text()
  190. # description is always initialized by empty text c.f., self.edit.text
  191. if description is None:
  192. description = ""
  193. self.editDescription.setPlainText(description)
  194. if flags:
  195. self.setFlags(flags)
  196. else:
  197. self.resetFlags(text)
  198. self.edit.setText(text)
  199. self.edit.setSelection(0, len(text))
  200. if group_id is None:
  201. self.edit_group_id.clear()
  202. else:
  203. self.edit_group_id.setText(str(group_id))
  204. items = self.labelList.findItems(text, QtCore.Qt.MatchFixedString)
  205. if items:
  206. if len(items) != 1:
  207. logger.warning("Label list has duplicate '{}'".format(text))
  208. self.labelList.setCurrentItem(items[0])
  209. row = self.labelList.row(items[0])
  210. self.edit.completer().setCurrentRow(row)
  211. self.edit.setFocus(QtCore.Qt.PopupFocusReason)
  212. if move:
  213. self.move(QtGui.QCursor.pos())
  214. if self.exec_():
  215. return (
  216. self.edit.text(),
  217. self.getFlags(),
  218. self.getGroupId(),
  219. self.editDescription.toPlainText(),
  220. )
  221. else:
  222. return None, None, None, None
  223. def mousePressEvent(self, event):
  224. if event.button() == QtCore.Qt.LeftButton:
  225. self.dragging = True
  226. self.startPos = event.globalPos() - self.frameGeometry().topLeft()
  227. event.accept()
  228. def mouseMoveEvent(self, event):
  229. if self.dragging:
  230. self.move(event.globalPos() - self.startPos)
  231. event.accept()
  232. def mouseReleaseEvent(self, event):
  233. if event.button() == QtCore.Qt.LeftButton:
  234. self.dragging = False
  235. event.accept()