label_dialog.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230
  1. import re
  2. from qtpy import QT_VERSION
  3. from qtpy import QtCore
  4. from qtpy import QtGui
  5. from qtpy import QtWidgets
  6. QT5 = QT_VERSION[0] == '5' # NOQA
  7. from labelme.logger import logger
  8. import labelme.utils
# TODO(unknown):
# - Calculate an optimal dialog position so that it does not extend beyond
#   the screen area.
  11. class LabelQLineEdit(QtWidgets.QLineEdit):
  12. def setListWidget(self, list_widget):
  13. self.list_widget = list_widget
  14. def keyPressEvent(self, e):
  15. if e.key() in [QtCore.Qt.Key_Up, QtCore.Qt.Key_Down]:
  16. self.list_widget.keyPressEvent(e)
  17. else:
  18. super(LabelQLineEdit, self).keyPressEvent(e)
  19. class LabelDialog(QtWidgets.QDialog):
  20. def __init__(self, text="Enter object label", parent=None, labels=None,
  21. sort_labels=True, show_text_field=True,
  22. completion='startswith', fit_to_content=None, flags=None):
  23. if fit_to_content is None:
  24. fit_to_content = {'row': False, 'column': True}
  25. self._fit_to_content = fit_to_content
  26. super(LabelDialog, self).__init__(parent)
  27. self.edit = LabelQLineEdit()
  28. self.edit.setPlaceholderText(text)
  29. self.edit.setValidator(labelme.utils.labelValidator())
  30. self.edit.editingFinished.connect(self.postProcess)
  31. if flags:
  32. self.edit.textChanged.connect(self.updateFlags)
  33. self.edit_group_id = QtWidgets.QLineEdit()
  34. self.edit_group_id.setPlaceholderText('Group ID')
  35. self.edit_group_id.setValidator(
  36. QtGui.QRegExpValidator(QtCore.QRegExp(r'\d*'), None)
  37. )
  38. layout = QtWidgets.QVBoxLayout()
  39. if show_text_field:
  40. layout_edit = QtWidgets.QHBoxLayout()
  41. layout_edit.addWidget(self.edit, 6)
  42. layout_edit.addWidget(self.edit_group_id, 2)
  43. layout.addLayout(layout_edit)
  44. # buttons
  45. self.buttonBox = bb = QtWidgets.QDialogButtonBox(
  46. QtWidgets.QDialogButtonBox.Ok | QtWidgets.QDialogButtonBox.Cancel,
  47. QtCore.Qt.Horizontal,
  48. self,
  49. )
  50. bb.button(bb.Ok).setIcon(labelme.utils.newIcon('done'))
  51. bb.button(bb.Cancel).setIcon(labelme.utils.newIcon('undo'))
  52. bb.accepted.connect(self.validate)
  53. bb.rejected.connect(self.reject)
  54. layout.addWidget(bb)
  55. # label_list
  56. self.labelList = QtWidgets.QListWidget()
  57. if self._fit_to_content['row']:
  58. self.labelList.setHorizontalScrollBarPolicy(
  59. QtCore.Qt.ScrollBarAlwaysOff
  60. )
  61. if self._fit_to_content['column']:
  62. self.labelList.setVerticalScrollBarPolicy(
  63. QtCore.Qt.ScrollBarAlwaysOff
  64. )
  65. self._sort_labels = sort_labels
  66. if labels:
  67. self.labelList.addItems(labels)
  68. if self._sort_labels:
  69. self.labelList.sortItems()
  70. else:
  71. self.labelList.setDragDropMode(
  72. QtWidgets.QAbstractItemView.InternalMove)
  73. self.labelList.currentItemChanged.connect(self.labelSelected)
  74. self.labelList.itemDoubleClicked.connect(self.labelDoubleClicked)
  75. self.edit.setListWidget(self.labelList)
  76. layout.addWidget(self.labelList)
  77. # label_flags
  78. if flags is None:
  79. flags = {}
  80. self._flags = flags
  81. self.flagsLayout = QtWidgets.QVBoxLayout()
  82. self.resetFlags()
  83. layout.addItem(self.flagsLayout)
  84. self.edit.textChanged.connect(self.updateFlags)
  85. self.setLayout(layout)
  86. # completion
  87. completer = QtWidgets.QCompleter()
  88. if not QT5 and completion != 'startswith':
  89. logger.warn(
  90. "completion other than 'startswith' is only "
  91. "supported with Qt5. Using 'startswith'"
  92. )
  93. completion = 'startswith'
  94. if completion == 'startswith':
  95. completer.setCompletionMode(QtWidgets.QCompleter.InlineCompletion)
  96. # Default settings.
  97. # completer.setFilterMode(QtCore.Qt.MatchStartsWith)
  98. elif completion == 'contains':
  99. completer.setCompletionMode(QtWidgets.QCompleter.PopupCompletion)
  100. completer.setFilterMode(QtCore.Qt.MatchContains)
  101. else:
  102. raise ValueError('Unsupported completion: {}'.format(completion))
  103. completer.setModel(self.labelList.model())
  104. self.edit.setCompleter(completer)
  105. def addLabelHistory(self, label):
  106. if self.labelList.findItems(label, QtCore.Qt.MatchExactly):
  107. return
  108. self.labelList.addItem(label)
  109. if self._sort_labels:
  110. self.labelList.sortItems()
  111. def labelSelected(self, item):
  112. self.edit.setText(item.text())
  113. def validate(self):
  114. text = self.edit.text()
  115. if hasattr(text, 'strip'):
  116. text = text.strip()
  117. else:
  118. text = text.trimmed()
  119. if text:
  120. self.accept()
  121. def labelDoubleClicked(self, item):
  122. self.validate()
  123. def postProcess(self):
  124. text = self.edit.text()
  125. if hasattr(text, 'strip'):
  126. text = text.strip()
  127. else:
  128. text = text.trimmed()
  129. self.edit.setText(text)
  130. def updateFlags(self, label_new):
  131. # keep state of shared flags
  132. flags_old = self.getFlags()
  133. flags_new = {}
  134. for pattern, keys in self._flags.items():
  135. if re.match(pattern, label_new):
  136. for key in keys:
  137. flags_new[key] = flags_old.get(key, False)
  138. self.setFlags(flags_new)
  139. def deleteFlags(self):
  140. for i in reversed(range(self.flagsLayout.count())):
  141. item = self.flagsLayout.itemAt(i).widget()
  142. self.flagsLayout.removeWidget(item)
  143. item.setParent(None)
  144. def resetFlags(self, label=''):
  145. flags = {}
  146. for pattern, keys in self._flags.items():
  147. if re.match(pattern, label):
  148. for key in keys:
  149. flags[key] = False
  150. self.setFlags(flags)
  151. def setFlags(self, flags):
  152. self.deleteFlags()
  153. for key in flags:
  154. item = QtWidgets.QCheckBox(key, self)
  155. item.setChecked(flags[key])
  156. self.flagsLayout.addWidget(item)
  157. item.show()
  158. def getFlags(self):
  159. flags = {}
  160. for i in range(self.flagsLayout.count()):
  161. item = self.flagsLayout.itemAt(i).widget()
  162. flags[item.text()] = item.isChecked()
  163. return flags
  164. def getGroupId(self):
  165. group_id = self.edit_group_id.text()
  166. if group_id:
  167. return int(group_id)
  168. return None
  169. def popUp(self, text=None, move=True, flags=None, group_id=None):
  170. if self._fit_to_content['row']:
  171. self.labelList.setMinimumHeight(
  172. self.labelList.sizeHintForRow(0) * self.labelList.count() + 2
  173. )
  174. if self._fit_to_content['column']:
  175. self.labelList.setMinimumWidth(
  176. self.labelList.sizeHintForColumn(0) + 2
  177. )
  178. # if text is None, the previous label in self.edit is kept
  179. if text is None:
  180. text = self.edit.text()
  181. if flags:
  182. self.setFlags(flags)
  183. else:
  184. self.resetFlags(text)
  185. self.edit.setText(text)
  186. self.edit.setSelection(0, len(text))
  187. if group_id is None:
  188. self.edit_group_id.clear()
  189. else:
  190. self.edit_group_id.setText(str(group_id))
  191. items = self.labelList.findItems(text, QtCore.Qt.MatchFixedString)
  192. if items:
  193. if len(items) != 1:
  194. logger.warning("Label list has duplicate '{}'".format(text))
  195. self.labelList.setCurrentItem(items[0])
  196. row = self.labelList.row(items[0])
  197. self.edit.completer().setCurrentRow(row)
  198. self.edit.setFocus(QtCore.Qt.PopupFocusReason)
  199. if move:
  200. self.move(QtGui.QCursor.pos())
  201. if self.exec_():
  202. return self.edit.text(), self.getFlags(), self.getGroupId()
  203. else:
  204. return None, None, None