123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273 |
- # Copyright (c) OpenMMLab. All rights reserved.
- import os.path as osp
- from typing import Union
- try:
- import seaborn as sns
- except ImportError:
- sns = None
- import cv2
- import matplotlib.pyplot as plt
- import mmcv
- import numpy as np
- from matplotlib.patches import Rectangle
- from mmengine.utils import mkdir_or_exist
def imshow_mot_errors(*args, backend: str = 'cv2', **kwargs):
    """Show the wrong tracks on the input image.

    This is a thin dispatcher: every positional and keyword argument other
    than ``backend`` is forwarded unchanged to the backend-specific
    implementation (``_cv2_show_wrong_tracks`` for ``'cv2'`` and
    ``_plt_show_wrong_tracks`` for ``'plt'``).

    Args:
        backend (str, optional): Backend of visualization, either ``'cv2'``
            or ``'plt'``. Defaults to 'cv2'.

    Returns:
        ndarray: Whatever image the selected backend returns.

    Raises:
        NotImplementedError: If ``backend`` is neither ``'cv2'`` nor
            ``'plt'``.
    """
    if backend == 'cv2':
        return _cv2_show_wrong_tracks(*args, **kwargs)
    if backend == 'plt':
        return _plt_show_wrong_tracks(*args, **kwargs)
    # Name the offending value so a typo in a config is easy to spot.
    raise NotImplementedError(
        f"Unsupported visualization backend '{backend}'; "
        "expected 'cv2' or 'plt'.")
def _cv2_show_wrong_tracks(img: Union[str, np.ndarray],
                           bboxes: np.ndarray,
                           ids: np.ndarray,
                           error_types: np.ndarray,
                           thickness: int = 2,
                           font_scale: float = 0.4,
                           text_width: int = 10,
                           text_height: int = 15,
                           show: bool = False,
                           wait_time: int = 100,
                           out_file: str = None) -> np.ndarray:
    """Show the wrong tracks with opencv.

    Every box is outlined in a colour chosen by its error type. For boxes
    that are not false negatives, the score and the id are also written on
    two filled label rows at the top-left corner of the box.

    Neither ``bboxes`` nor an array ``img`` is modified: clipping and drawing
    happen on copies.

    Args:
        img (str or ndarray): The image to be displayed. A path is read with
            ``mmcv.imread``; an array must be 3-dimensional.
        bboxes (ndarray): A ndarray of shape (k, 5), each row being
            ``(x1, y1, x2, y2, score)``.
        ids (ndarray): A ndarray of shape (k, ).
        error_types (ndarray): A ndarray of shape (k, ), where 0 denotes
            false positives, 1 denotes false negative and 2 denotes ID switch.
        thickness (int, optional): Thickness of the box lines.
            Defaults to 2.
        font_scale (float, optional): Font scale to draw id and score.
            Defaults to 0.4.
        text_width (int, optional): Per-character width, in pixels, of the
            filled label background. Defaults to 10.
        text_height (int, optional): Height, in pixels, of each label row.
            Defaults to 15.
        show (bool, optional): Whether to show the image on the fly.
            Defaults to False.
        wait_time (int, optional): Forwarded as ``wait_time`` to
            ``mmcv.imshow`` when ``show`` is True. Defaults to 100.
        out_file (str, optional): The filename to write the image.
            Defaults to None.

    Returns:
        ndarray: Visualized image.

    Raises:
        ImportError: If seaborn is not installed.
    """
    if sns is None:
        raise ImportError('please run pip install seaborn')
    assert bboxes.ndim == 2, \
        f' bboxes ndim should be 2, but its ndim is {bboxes.ndim}.'
    assert ids.ndim == 1, \
        f' ids ndim should be 1, but its ndim is {ids.ndim}.'
    assert error_types.ndim == 1, \
        f' error_types ndim should be 1, but its ndim is {error_types.ndim}.'
    assert bboxes.shape[0] == ids.shape[0], \
        'bboxes.shape[0] and ids.shape[0] should have the same length.'
    # Without this check zip() below would silently drop trailing boxes.
    assert bboxes.shape[0] == error_types.shape[0], \
        'bboxes.shape[0] and error_types.shape[0] should have the same length.'
    assert bboxes.shape[1] == 5, \
        f' bboxes.shape[1] should be 5, but its {bboxes.shape[1]}.'

    palette = sns.color_palette()
    # red, yellow, blue -> indexed by error type (0: FP, 1: FN, 2: IDSw).
    # The palette holds floats in [0, 1]; scale to 0-255 ints and reverse the
    # channel order to BGR for OpenCV.
    bbox_colors = [[int(255 * c) for c in palette[idx]][::-1]
                   for idx in (3, 1, 0)]

    if isinstance(img, str):
        img = mmcv.imread(img)
    else:
        assert img.ndim == 3
        # Draw on a private (contiguous) copy so the caller's array is not
        # altered.
        img = img.copy()
    img_h, img_w = img.shape[:2]

    # Clip only the coordinate columns (0-3) of a copy; a plain ``0::2``
    # slice would also hit the score in column 4 and modify the caller's
    # array in place.
    bboxes = bboxes.copy()
    bboxes[:, 0:4:2] = np.clip(bboxes[:, 0:4:2], 0, img_w)
    bboxes[:, 1:4:2] = np.clip(bboxes[:, 1:4:2], 0, img_h)

    for bbox, error_type, track_id in zip(bboxes, error_types, ids):
        # Accept integer-valued floats / numpy scalars as list indices.
        error_type = int(error_type)
        x1, y1, x2, y2 = (int(v) for v in bbox[:4].astype(np.int32))
        score = float(bbox[4])
        bbox_color = bbox_colors[error_type]
        cv2.rectangle(
            img, (x1, y1), (x2, y2), bbox_color, thickness=thickness)

        # FN does not have id and score
        if error_type == 1:
            continue

        # Score label: filled background row, then black text on top.
        text = '{:.02f}'.format(score)
        # NOTE: one character narrower than the text (kept from the original
        # layout; presumably because '.' renders narrow).
        width = (len(text) - 1) * text_width
        img[y1:y1 + text_height, x1:x1 + width, :] = bbox_color
        cv2.putText(
            img,
            text, (x1, y1 + text_height - 2),
            cv2.FONT_HERSHEY_COMPLEX,
            font_scale,
            color=(0, 0, 0))

        # Id label: second row directly below the score row.
        text = str(track_id)
        width = len(text) * text_width
        img[y1 + text_height:y1 + text_height * 2,
            x1:x1 + width, :] = bbox_color
        cv2.putText(
            img,
            text, (x1, y1 + text_height * 2 - 2),
            cv2.FONT_HERSHEY_COMPLEX,
            font_scale,
            color=(0, 0, 0))

    if show:
        mmcv.imshow(img, wait_time=wait_time)
    if out_file is not None:
        mmcv.imwrite(img, out_file)
    return img
def _plt_show_wrong_tracks(img: Union[str, np.ndarray],
                           bboxes: np.ndarray,
                           ids: np.ndarray,
                           error_types: np.ndarray,
                           thickness: float = 0.1,
                           font_scale: float = 3.0,
                           text_width: int = 8,
                           text_height: int = 13,
                           show: bool = False,
                           wait_time: int = 100,
                           out_file: str = None) -> np.ndarray:
    """Show the wrong tracks with matplotlib.

    The image and the boxes are drawn onto the current matplotlib figure.
    Every box is outlined in a colour chosen by its error type. For boxes
    that are not false negatives, the score and the id are also written on
    two filled label rows at the top-left corner of the box. The figure is
    optionally saved and/or shown, and it is always cleared before
    returning.

    ``bboxes`` is not modified; clipping happens on a copy.

    Args:
        img (str or ndarray): The image to be displayed. A path is read with
            ``plt.imread``; an array must be 3-dimensional and is converted
            with ``mmcv.bgr2rgb``.
        bboxes (ndarray): A ndarray of shape (k, 5), each row being
            ``(x1, y1, x2, y2, score)``.
        ids (ndarray): A ndarray of shape (k, ).
        error_types (ndarray): A ndarray of shape (k, ), where 0 denotes
            false positives, 1 denotes false negative and 2 denotes ID switch.
        thickness (float, optional): Line width of the box and label
            outlines, passed to matplotlib as ``linewidth``.
            Defaults to 0.1.
        font_scale (float, optional): Font size of the id and score text.
            Defaults to 3.0.
        text_width (int, optional): Per-character width of the filled label
            background, in image coordinates. Defaults to 8.
        text_height (int, optional): Height of each label row, in image
            coordinates. Defaults to 13.
        show (bool, optional): Whether to show the image on the fly.
            Defaults to False.
        wait_time (int, optional): Milliseconds to pause for while showing
            the figure when ``show`` is True. Defaults to 100.
        out_file (str, optional): The filename to write the image.
            Defaults to None.

    Returns:
        ndarray: The image that was displayed (after the BGR->RGB conversion
        when an array was given). The drawn overlays are not part of it.

    Raises:
        ImportError: If seaborn is not installed.
    """
    # Same guard as the cv2 backend; otherwise a missing seaborn surfaces as
    # an obscure AttributeError on ``None``.
    if sns is None:
        raise ImportError('please run pip install seaborn')
    assert bboxes.ndim == 2, \
        f' bboxes ndim should be 2, but its ndim is {bboxes.ndim}.'
    assert ids.ndim == 1, \
        f' ids ndim should be 1, but its ndim is {ids.ndim}.'
    assert error_types.ndim == 1, \
        f' error_types ndim should be 1, but its ndim is {error_types.ndim}.'
    assert bboxes.shape[0] == ids.shape[0], \
        'bboxes.shape[0] and ids.shape[0] should have the same length.'
    # Without this check zip() below would silently drop trailing boxes.
    assert bboxes.shape[0] == error_types.shape[0], \
        'bboxes.shape[0] and error_types.shape[0] should have the same length.'
    assert bboxes.shape[1] == 5, \
        f' bboxes.shape[1] should be 5, but its {bboxes.shape[1]}.'

    palette = sns.color_palette()
    # red, yellow, blue -> indexed by error type (0: FP, 1: FN, 2: IDSw).
    bbox_colors = [palette[3], palette[1], palette[0]]

    if isinstance(img, str):
        img = plt.imread(img)
    else:
        assert img.ndim == 3
        img = mmcv.bgr2rgb(img)
    img_h, img_w = img.shape[:2]

    # Clip only the coordinate columns (0-3) of a copy; a plain ``0::2``
    # slice would also hit the score in column 4 and modify the caller's
    # array in place.
    bboxes = bboxes.copy()
    bboxes[:, 0:4:2] = np.clip(bboxes[:, 0:4:2], 0, img_w)
    bboxes[:, 1:4:2] = np.clip(bboxes[:, 1:4:2], 0, img_h)

    plt.imshow(img)
    ax = plt.gca()
    ax.set_axis_off()
    plt.autoscale(False)
    plt.subplots_adjust(
        top=1, bottom=0, right=1, left=0, hspace=None, wspace=None)
    plt.margins(0, 0)
    ax.xaxis.set_major_locator(plt.NullLocator())
    ax.yaxis.set_major_locator(plt.NullLocator())
    # NOTE(review): this runs after plt.imshow has already created the
    # figure, so it looks like it only changes the default size of figures
    # created later (global state). Kept for backward compatibility.
    plt.rcParams['figure.figsize'] = img_w, img_h

    for bbox, error_type, track_id in zip(bboxes, error_types, ids):
        # Accept integer-valued floats / numpy scalars as list indices.
        error_type = int(error_type)
        color = bbox_colors[error_type]
        x1, y1, x2, y2, score = bbox
        left, top = int(x1), int(y1)
        w, h = int(x2 - x1), int(y2 - y1)

        # Box outline. ``thickness`` must be passed as ``linewidth``: the
        # 4th positional argument of Rectangle is the rotation ``angle``.
        ax.add_patch(
            Rectangle((left, top),
                      w,
                      h,
                      linewidth=thickness,
                      edgecolor=color,
                      facecolor='none'))

        # FN does not have id and score
        if error_type == 1:
            continue

        # Score label: filled background row, then the text.
        text = '{:.02f}'.format(score)
        width = len(text) * text_width
        ax.add_patch(
            Rectangle((left, top),
                      width,
                      text_height,
                      linewidth=thickness,
                      edgecolor=color,
                      facecolor=color))
        plt.text(left, top + text_height + 2, text, fontsize=font_scale)

        # Id label: second row directly below the score row.
        text = str(track_id)
        width = len(text) * text_width
        ax.add_patch(
            Rectangle((left, top + text_height + 1),
                      width,
                      text_height,
                      linewidth=thickness,
                      edgecolor=color,
                      facecolor=color))
        plt.text(
            left, top + 2 * (text_height + 1), text, fontsize=font_scale)

    if out_file is not None:
        mkdir_or_exist(osp.abspath(osp.dirname(out_file)))
        plt.savefig(out_file, dpi=300, bbox_inches='tight', pad_inches=0.0)
    if show:
        plt.draw()
        plt.pause(wait_time / 1000.)
    # Always reset the figure so the next call starts from a blank canvas.
    plt.clf()
    return img
|