mot_error_visualize.py 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import os.path as osp
  3. from typing import Union
  4. try:
  5. import seaborn as sns
  6. except ImportError:
  7. sns = None
  8. import cv2
  9. import matplotlib.pyplot as plt
  10. import mmcv
  11. import numpy as np
  12. from matplotlib.patches import Rectangle
  13. from mmengine.utils import mkdir_or_exist
  14. def imshow_mot_errors(*args, backend: str = 'cv2', **kwargs):
  15. """Show the wrong tracks on the input image.
  16. Args:
  17. backend (str, optional): Backend of visualization.
  18. Defaults to 'cv2'.
  19. """
  20. if backend == 'cv2':
  21. return _cv2_show_wrong_tracks(*args, **kwargs)
  22. elif backend == 'plt':
  23. return _plt_show_wrong_tracks(*args, **kwargs)
  24. else:
  25. raise NotImplementedError()
  26. def _cv2_show_wrong_tracks(img: Union[str, np.ndarray],
  27. bboxes: np.ndarray,
  28. ids: np.ndarray,
  29. error_types: np.ndarray,
  30. thickness: int = 2,
  31. font_scale: float = 0.4,
  32. text_width: int = 10,
  33. text_height: int = 15,
  34. show: bool = False,
  35. wait_time: int = 100,
  36. out_file: str = None) -> np.ndarray:
  37. """Show the wrong tracks with opencv.
  38. Args:
  39. img (str or ndarray): The image to be displayed.
  40. bboxes (ndarray): A ndarray of shape (k, 5).
  41. ids (ndarray): A ndarray of shape (k, ).
  42. error_types (ndarray): A ndarray of shape (k, ), where 0 denotes
  43. false positives, 1 denotes false negative and 2 denotes ID switch.
  44. thickness (int, optional): Thickness of lines.
  45. Defaults to 2.
  46. font_scale (float, optional): Font scale to draw id and score.
  47. Defaults to 0.4.
  48. text_width (int, optional): Width to draw id and score.
  49. Defaults to 10.
  50. text_height (int, optional): Height to draw id and score.
  51. Defaults to 15.
  52. show (bool, optional): Whether to show the image on the fly.
  53. Defaults to False.
  54. wait_time (int, optional): Value of waitKey param.
  55. Defaults to 100.
  56. out_file (str, optional): The filename to write the image.
  57. Defaults to None.
  58. Returns:
  59. ndarray: Visualized image.
  60. """
  61. if sns is None:
  62. raise ImportError('please run pip install seaborn')
  63. assert bboxes.ndim == 2, \
  64. f' bboxes ndim should be 2, but its ndim is {bboxes.ndim}.'
  65. assert ids.ndim == 1, \
  66. f' ids ndim should be 1, but its ndim is {ids.ndim}.'
  67. assert error_types.ndim == 1, \
  68. f' error_types ndim should be 1, but its ndim is {error_types.ndim}.'
  69. assert bboxes.shape[0] == ids.shape[0], \
  70. 'bboxes.shape[0] and ids.shape[0] should have the same length.'
  71. assert bboxes.shape[1] == 5, \
  72. f' bboxes.shape[1] should be 5, but its {bboxes.shape[1]}.'
  73. bbox_colors = sns.color_palette()
  74. # red, yellow, blue
  75. bbox_colors = [bbox_colors[3], bbox_colors[1], bbox_colors[0]]
  76. bbox_colors = [[int(255 * _c) for _c in bbox_color][::-1]
  77. for bbox_color in bbox_colors]
  78. if isinstance(img, str):
  79. img = mmcv.imread(img)
  80. else:
  81. assert img.ndim == 3
  82. img_shape = img.shape
  83. bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0, img_shape[1])
  84. bboxes[:, 1::2] = np.clip(bboxes[:, 1::2], 0, img_shape[0])
  85. for bbox, error_type, id in zip(bboxes, error_types, ids):
  86. x1, y1, x2, y2 = bbox[:4].astype(np.int32)
  87. score = float(bbox[-1])
  88. # bbox
  89. bbox_color = bbox_colors[error_type]
  90. cv2.rectangle(img, (x1, y1), (x2, y2), bbox_color, thickness=thickness)
  91. # FN does not have id and score
  92. if error_type == 1:
  93. continue
  94. # score
  95. text = '{:.02f}'.format(score)
  96. width = (len(text) - 1) * text_width
  97. img[y1:y1 + text_height, x1:x1 + width, :] = bbox_color
  98. cv2.putText(
  99. img,
  100. text, (x1, y1 + text_height - 2),
  101. cv2.FONT_HERSHEY_COMPLEX,
  102. font_scale,
  103. color=(0, 0, 0))
  104. # id
  105. text = str(id)
  106. width = len(text) * text_width
  107. img[y1 + text_height:y1 + text_height * 2,
  108. x1:x1 + width, :] = bbox_color
  109. cv2.putText(
  110. img,
  111. str(id), (x1, y1 + text_height * 2 - 2),
  112. cv2.FONT_HERSHEY_COMPLEX,
  113. font_scale,
  114. color=(0, 0, 0))
  115. if show:
  116. mmcv.imshow(img, wait_time=wait_time)
  117. if out_file is not None:
  118. mmcv.imwrite(img, out_file)
  119. return img
  120. def _plt_show_wrong_tracks(img: Union[str, np.ndarray],
  121. bboxes: np.ndarray,
  122. ids: np.ndarray,
  123. error_types: np.ndarray,
  124. thickness: float = 0.1,
  125. font_scale: float = 3.0,
  126. text_width: int = 8,
  127. text_height: int = 13,
  128. show: bool = False,
  129. wait_time: int = 100,
  130. out_file: str = None) -> np.ndarray:
  131. """Show the wrong tracks with matplotlib.
  132. Args:
  133. img (str or ndarray): The image to be displayed.
  134. bboxes (ndarray): A ndarray of shape (k, 5).
  135. ids (ndarray): A ndarray of shape (k, ).
  136. error_types (ndarray): A ndarray of shape (k, ), where 0 denotes
  137. false positives, 1 denotes false negative and 2 denotes ID switch.
  138. thickness (float, optional): Thickness of lines.
  139. Defaults to 0.1.
  140. font_scale (float, optional): Font scale to draw id and score.
  141. Defaults to 3.0.
  142. text_width (int, optional): Width to draw id and score.
  143. Defaults to 8.
  144. text_height (int, optional): Height to draw id and score.
  145. Defaults to 13.
  146. show (bool, optional): Whether to show the image on the fly.
  147. Defaults to False.
  148. wait_time (int, optional): Value of waitKey param.
  149. Defaults to 100.
  150. out_file (str, optional): The filename to write the image.
  151. Defaults to None.
  152. Returns:
  153. ndarray: Original image.
  154. """
  155. assert bboxes.ndim == 2, \
  156. f' bboxes ndim should be 2, but its ndim is {bboxes.ndim}.'
  157. assert ids.ndim == 1, \
  158. f' ids ndim should be 1, but its ndim is {ids.ndim}.'
  159. assert error_types.ndim == 1, \
  160. f' error_types ndim should be 1, but its ndim is {error_types.ndim}.'
  161. assert bboxes.shape[0] == ids.shape[0], \
  162. 'bboxes.shape[0] and ids.shape[0] should have the same length.'
  163. assert bboxes.shape[1] == 5, \
  164. f' bboxes.shape[1] should be 5, but its {bboxes.shape[1]}.'
  165. bbox_colors = sns.color_palette()
  166. # red, yellow, blue
  167. bbox_colors = [bbox_colors[3], bbox_colors[1], bbox_colors[0]]
  168. if isinstance(img, str):
  169. img = plt.imread(img)
  170. else:
  171. assert img.ndim == 3
  172. img = mmcv.bgr2rgb(img)
  173. img_shape = img.shape
  174. bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0, img_shape[1])
  175. bboxes[:, 1::2] = np.clip(bboxes[:, 1::2], 0, img_shape[0])
  176. plt.imshow(img)
  177. plt.gca().set_axis_off()
  178. plt.autoscale(False)
  179. plt.subplots_adjust(
  180. top=1, bottom=0, right=1, left=0, hspace=None, wspace=None)
  181. plt.margins(0, 0)
  182. plt.gca().xaxis.set_major_locator(plt.NullLocator())
  183. plt.gca().yaxis.set_major_locator(plt.NullLocator())
  184. plt.rcParams['figure.figsize'] = img_shape[1], img_shape[0]
  185. for bbox, error_type, id in zip(bboxes, error_types, ids):
  186. x1, y1, x2, y2, score = bbox
  187. w, h = int(x2 - x1), int(y2 - y1)
  188. left_top = (int(x1), int(y1))
  189. # bbox
  190. plt.gca().add_patch(
  191. Rectangle(
  192. left_top,
  193. w,
  194. h,
  195. thickness,
  196. edgecolor=bbox_colors[error_type],
  197. facecolor='none'))
  198. # FN does not have id and score
  199. if error_type == 1:
  200. continue
  201. # score
  202. text = '{:.02f}'.format(score)
  203. width = len(text) * text_width
  204. plt.gca().add_patch(
  205. Rectangle((left_top[0], left_top[1]),
  206. width,
  207. text_height,
  208. thickness,
  209. edgecolor=bbox_colors[error_type],
  210. facecolor=bbox_colors[error_type]))
  211. plt.text(
  212. left_top[0],
  213. left_top[1] + text_height + 2,
  214. text,
  215. fontsize=font_scale)
  216. # id
  217. text = str(id)
  218. width = len(text) * text_width
  219. plt.gca().add_patch(
  220. Rectangle((left_top[0], left_top[1] + text_height + 1),
  221. width,
  222. text_height,
  223. thickness,
  224. edgecolor=bbox_colors[error_type],
  225. facecolor=bbox_colors[error_type]))
  226. plt.text(
  227. left_top[0],
  228. left_top[1] + 2 * (text_height + 1),
  229. text,
  230. fontsize=font_scale)
  231. if out_file is not None:
  232. mkdir_or_exist(osp.abspath(osp.dirname(out_file)))
  233. plt.savefig(out_file, dpi=300, bbox_inches='tight', pad_inches=0.0)
  234. if show:
  235. plt.draw()
  236. plt.pause(wait_time / 1000.)
  237. plt.clf()
  238. return img