local_visualizer.py

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Tuple, Union

import cv2
import mmcv
import numpy as np

try:
    import seaborn as sns
except ImportError:
    sns = None

import torch
from mmengine.dist import master_only
from mmengine.structures import InstanceData, PixelData
from mmengine.visualization import Visualizer

from ..evaluation import INSTANCE_OFFSET
from ..registry import VISUALIZERS
from ..structures import DetDataSample
from ..structures.mask import BitmapMasks, PolygonMasks, bitmap_to_polygon
from .palette import _get_adaptive_scales, get_palette, jitter_color


@VISUALIZERS.register_module()
class DetLocalVisualizer(Visualizer):
    """MMDetection Local Visualizer.

    Args:
        name (str): Name of the instance. Defaults to 'visualizer'.
        image (np.ndarray, optional): The origin image to draw. The format
            should be RGB. Defaults to None.
        vis_backends (list, optional): Visual backend config list.
            Defaults to None.
        save_dir (str, optional): Save file dir for all storage backends.
            If it is None, the backend storage will not save any data.
        bbox_color (str, tuple(int), optional): Color of bbox lines.
            The tuple of color should be in BGR order. Defaults to None.
        text_color (str, tuple(int), optional): Color of texts.
            The tuple of color should be in BGR order.
            Defaults to (200, 200, 200).
        mask_color (str, tuple(int), optional): Color of masks.
            The tuple of color should be in BGR order.
            Defaults to None.
        line_width (int, float): The linewidth of lines.
            Defaults to 3.
        alpha (int, float): The transparency of bboxes or mask.
            Defaults to 0.8.

    Examples:
        >>> import numpy as np
        >>> import torch
        >>> from mmengine.structures import InstanceData
        >>> from mmdet.structures import DetDataSample
        >>> from mmdet.visualization import DetLocalVisualizer

        >>> det_local_visualizer = DetLocalVisualizer()
        >>> image = np.random.randint(0, 256,
        ...                           size=(10, 12, 3)).astype('uint8')
        >>> gt_instances = InstanceData()
        >>> gt_instances.bboxes = torch.Tensor([[1, 2, 2, 5]])
        >>> gt_instances.labels = torch.randint(0, 2, (1,))
        >>> gt_det_data_sample = DetDataSample()
        >>> gt_det_data_sample.gt_instances = gt_instances
        >>> det_local_visualizer.add_datasample('image', image,
        ...                                     gt_det_data_sample)
        >>> det_local_visualizer.add_datasample(
        ...     'image', image, gt_det_data_sample,
        ...     out_file='out_file.jpg')
        >>> det_local_visualizer.add_datasample(
        ...     'image', image, gt_det_data_sample,
        ...     show=True)
        >>> pred_instances = InstanceData()
        >>> pred_instances.bboxes = torch.Tensor([[2, 4, 4, 8]])
        >>> pred_instances.labels = torch.randint(0, 2, (1,))
        >>> pred_det_data_sample = DetDataSample()
        >>> pred_det_data_sample.pred_instances = pred_instances
        >>> det_local_visualizer.add_datasample('image', image,
        ...                                     gt_det_data_sample,
        ...                                     pred_det_data_sample)
    """

    def __init__(self,
                 name: str = 'visualizer',
                 image: Optional[np.ndarray] = None,
                 vis_backends: Optional[Dict] = None,
                 save_dir: Optional[str] = None,
                 bbox_color: Optional[Union[str, Tuple[int]]] = None,
                 text_color: Optional[Union[str,
                                            Tuple[int]]] = (200, 200, 200),
                 mask_color: Optional[Union[str, Tuple[int]]] = None,
                 line_width: Union[int, float] = 3,
                 alpha: float = 0.8) -> None:
        super().__init__(
            name=name,
            image=image,
            vis_backends=vis_backends,
            save_dir=save_dir)
        self.bbox_color = bbox_color
        self.text_color = text_color
        self.mask_color = mask_color
        self.line_width = line_width
        self.alpha = alpha
        # Set default value. When calling
        # `DetLocalVisualizer().dataset_meta=xxx`,
        # it will override the default value.
        self.dataset_meta = {}

    def _draw_instances(self, image: np.ndarray, instances: InstanceData,
                        classes: Optional[List[str]],
                        palette: Optional[List[tuple]]) -> np.ndarray:
        """Draw instances of GT or prediction.

        Args:
            image (np.ndarray): The image to draw.
            instances (:obj:`InstanceData`): Data structure for
                instance-level annotations or predictions.
            classes (List[str], optional): Category information.
            palette (List[tuple], optional): Palette information
                corresponding to the category.

        Returns:
            np.ndarray: The drawn image in RGB format.
        """
        self.set_image(image)

        if 'bboxes' in instances and instances.bboxes.sum() > 0:
            bboxes = instances.bboxes
            labels = instances.labels

            max_label = int(max(labels) if len(labels) > 0 else 0)
            text_palette = get_palette(self.text_color, max_label + 1)
            text_colors = [text_palette[label] for label in labels]

            bbox_color = palette if self.bbox_color is None \
                else self.bbox_color
            bbox_palette = get_palette(bbox_color, max_label + 1)
            colors = [bbox_palette[label] for label in labels]
            self.draw_bboxes(
                bboxes,
                edge_colors=colors,
                alpha=self.alpha,
                line_widths=self.line_width)

            positions = bboxes[:, :2] + self.line_width
            areas = (bboxes[:, 3] - bboxes[:, 1]) * (
                bboxes[:, 2] - bboxes[:, 0])
            scales = _get_adaptive_scales(areas)

            for i, (pos, label) in enumerate(zip(positions, labels)):
                if 'label_names' in instances:
                    label_text = instances.label_names[i]
                else:
                    label_text = classes[
                        label] if classes is not None else f'class {label}'
                if 'scores' in instances:
                    score = round(float(instances.scores[i]) * 100, 1)
                    label_text += f': {score}'

                self.draw_texts(
                    label_text,
                    pos,
                    colors=text_colors[i],
                    font_sizes=int(13 * scales[i]),
                    bboxes=[{
                        'facecolor': 'black',
                        'alpha': 0.8,
                        'pad': 0.7,
                        'edgecolor': 'none'
                    }])

        if 'masks' in instances:
            labels = instances.labels
            masks = instances.masks
            if isinstance(masks, torch.Tensor):
                masks = masks.numpy()
            elif isinstance(masks, (PolygonMasks, BitmapMasks)):
                masks = masks.to_ndarray()

            masks = masks.astype(bool)

            max_label = int(max(labels) if len(labels) > 0 else 0)
            mask_color = palette if self.mask_color is None \
                else self.mask_color
            mask_palette = get_palette(mask_color, max_label + 1)
            colors = [jitter_color(mask_palette[label]) for label in labels]
            text_palette = get_palette(self.text_color, max_label + 1)
            text_colors = [text_palette[label] for label in labels]

            polygons = []
            for i, mask in enumerate(masks):
                contours, _ = bitmap_to_polygon(mask)
                polygons.extend(contours)
            self.draw_polygons(polygons, edge_colors='w', alpha=self.alpha)
            self.draw_binary_masks(masks, colors=colors, alphas=self.alpha)

            if len(labels) > 0 and \
                    ('bboxes' not in instances or
                     instances.bboxes.sum() == 0):
                # An all-zero `bboxes` tensor represents dummy bboxes.
                # A typical example is SOLO, which has no bbox branch, so
                # the label is drawn at the centroid of the largest
                # connected component of each mask instead.
                areas = []
                positions = []
                for mask in masks:
                    _, _, stats, centroids = cv2.connectedComponentsWithStats(
                        mask.astype(np.uint8), connectivity=8)
                    if stats.shape[0] > 1:
                        largest_id = np.argmax(stats[1:, -1]) + 1
                        positions.append(centroids[largest_id])
                        areas.append(stats[largest_id, -1])
                areas = np.stack(areas, axis=0)
                scales = _get_adaptive_scales(areas)

                for i, (pos, label) in enumerate(zip(positions, labels)):
                    if 'label_names' in instances:
                        label_text = instances.label_names[i]
                    else:
                        label_text = classes[
                            label] if classes is not None else \
                            f'class {label}'
                    if 'scores' in instances:
                        score = round(float(instances.scores[i]) * 100, 1)
                        label_text += f': {score}'

                    self.draw_texts(
                        label_text,
                        pos,
                        colors=text_colors[i],
                        font_sizes=int(13 * scales[i]),
                        horizontal_alignments='center',
                        bboxes=[{
                            'facecolor': 'black',
                            'alpha': 0.8,
                            'pad': 0.7,
                            'edgecolor': 'none'
                        }])

        return self.get_image()

    def _draw_panoptic_seg(self, image: np.ndarray, panoptic_seg: PixelData,
                           classes: Optional[List[str]],
                           palette: Optional[List]) -> np.ndarray:
        """Draw panoptic seg of GT or prediction.

        Args:
            image (np.ndarray): The image to draw.
            panoptic_seg (:obj:`PixelData`): Data structure for
                pixel-level annotations or predictions.
            classes (List[str], optional): Category information.
            palette (List, optional): Palette information corresponding
                to the category.

        Returns:
            np.ndarray: The drawn image in RGB format.
        """
        # TODO: Is there a way to bypass?
        num_classes = len(classes)

        panoptic_seg_data = panoptic_seg.sem_seg[0]

        ids = np.unique(panoptic_seg_data)[::-1]

        if 'label_names' in panoptic_seg:
            # open set panoptic segmentation
            classes = panoptic_seg.metainfo['label_names']
            ignore_index = panoptic_seg.metainfo.get('ignore_index',
                                                     len(classes))
            ids = ids[ids != ignore_index]
        else:
            # for VOID label
            ids = ids[ids != num_classes]
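
        # Each panoptic id is assumed to encode
        # ``label + instance_index * INSTANCE_OFFSET``, so taking the id
        # modulo INSTANCE_OFFSET below recovers the semantic label, while
        # distinct ids still yield distinct per-instance masks.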
        labels = np.array([id % INSTANCE_OFFSET for id in ids],
                          dtype=np.int64)
        segms = (panoptic_seg_data[None] == ids[:, None, None])

        max_label = int(max(labels) if len(labels) > 0 else 0)

        mask_color = palette if self.mask_color is None \
            else self.mask_color
        mask_palette = get_palette(mask_color, max_label + 1)
        colors = [mask_palette[label] for label in labels]

        self.set_image(image)

        # draw segm
        polygons = []
        for i, mask in enumerate(segms):
            contours, _ = bitmap_to_polygon(mask)
            polygons.extend(contours)
        self.draw_polygons(polygons, edge_colors='w', alpha=self.alpha)
        self.draw_binary_masks(segms, colors=colors, alphas=self.alpha)

        # draw label
        areas = []
        positions = []
        for mask in segms:
            _, _, stats, centroids = cv2.connectedComponentsWithStats(
                mask.astype(np.uint8), connectivity=8)
            max_id = np.argmax(stats[1:, -1]) + 1
            positions.append(centroids[max_id])
            areas.append(stats[max_id, -1])
        areas = np.stack(areas, axis=0)
        scales = _get_adaptive_scales(areas)

        text_palette = get_palette(self.text_color, max_label + 1)
        text_colors = [text_palette[label] for label in labels]

        for i, (pos, label) in enumerate(zip(positions, labels)):
            label_text = classes[label]

            self.draw_texts(
                label_text,
                pos,
                colors=text_colors[i],
                font_sizes=int(13 * scales[i]),
                bboxes=[{
                    'facecolor': 'black',
                    'alpha': 0.8,
                    'pad': 0.7,
                    'edgecolor': 'none'
                }],
                horizontal_alignments='center')

        return self.get_image()

    def _draw_sem_seg(self, image: np.ndarray, sem_seg: PixelData,
                      classes: Optional[List],
                      palette: Optional[List]) -> np.ndarray:
        """Draw semantic seg of GT or prediction.

        Args:
            image (np.ndarray): The image to draw.
            sem_seg (:obj:`PixelData`): Data structure for pixel-level
                annotations or predictions.
            classes (list, optional): Input classes for result rendering.
                As the prediction of a segmentation model is a segment map
                with label indices, `classes` is a list whose items
                correspond to the label indices. If classes is not defined,
                the visualizer will take `cityscapes` classes by default.
                Defaults to None.
            palette (list, optional): Input palette for result rendering,
                which is a list of colors corresponding to the classes.
                Defaults to None.

        Returns:
            np.ndarray: The drawn image in RGB format.
        """
        sem_seg_data = sem_seg.sem_seg
        if isinstance(sem_seg_data, torch.Tensor):
            sem_seg_data = sem_seg_data.numpy()

        # 0 ~ num_class, the value 0 means background
        ids = np.unique(sem_seg_data)
        ignore_index = sem_seg.metainfo.get('ignore_index', 255)
        ids = ids[ids != ignore_index]

        if 'label_names' in sem_seg:
            # open set semseg
            label_names = sem_seg.metainfo['label_names']
        else:
            label_names = classes

        labels = np.array(ids, dtype=np.int64)
        colors = [palette[label] for label in labels]

        self.set_image(image)

        # draw semantic masks
        for i, (label, color) in enumerate(zip(labels, colors)):
            masks = sem_seg_data == label
            self.draw_binary_masks(masks, colors=[color], alphas=self.alpha)
            label_text = label_names[label]
            _, _, stats, centroids = cv2.connectedComponentsWithStats(
                masks[0].astype(np.uint8), connectivity=8)
            if stats.shape[0] > 1:
                largest_id = np.argmax(stats[1:, -1]) + 1
                centroids = centroids[largest_id]

                areas = stats[largest_id, -1]
                scales = _get_adaptive_scales(areas)
                self.draw_texts(
                    label_text,
                    centroids,
                    colors=(255, 255, 255),
                    font_sizes=int(13 * scales),
                    horizontal_alignments='center',
                    bboxes=[{
                        'facecolor': 'black',
                        'alpha': 0.8,
                        'pad': 0.7,
                        'edgecolor': 'none'
                    }])

        return self.get_image()

    @master_only
    def add_datasample(
            self,
            name: str,
            image: np.ndarray,
            data_sample: Optional['DetDataSample'] = None,
            draw_gt: bool = False,
            draw_pred: bool = True,
            show: bool = False,
            wait_time: float = 0,
            # TODO: Supported in mmengine's Visualizer.
            out_file: Optional[str] = None,
            pred_score_thr: float = 0.3,
            step: int = 0) -> None:
        """Draw datasample and save to all backends.

        - If GT and prediction are plotted at the same time, they are
          displayed in a stitched image where the left image is the
          ground truth and the right image is the prediction.
        - If ``show`` is True, all storage backends are ignored, and
          the images will be displayed in a local window.
        - If ``out_file`` is specified, the drawn image will be
          saved to ``out_file``. It is usually used when the display
          is not available.

        Args:
            name (str): The image identifier.
            image (np.ndarray): The image to draw.
            data_sample (:obj:`DetDataSample`, optional): A data
                sample that contains annotations and predictions.
                Defaults to None.
            draw_gt (bool): Whether to draw GT DetDataSample.
                Defaults to False.
            draw_pred (bool): Whether to draw Prediction DetDataSample.
                Defaults to True.
            show (bool): Whether to display the drawn image.
                Defaults to False.
            wait_time (float): The interval of show (s). Defaults to 0.
            out_file (str): Path to output file. Defaults to None.
            pred_score_thr (float): The threshold to visualize the bboxes
                and masks. Defaults to 0.3.
            step (int): Global step value to record. Defaults to 0.
        """
        image = image.clip(0, 255).astype(np.uint8)
        classes = self.dataset_meta.get('classes', None)
        palette = self.dataset_meta.get('palette', None)

        gt_img_data = None
        pred_img_data = None

        if data_sample is not None:
            data_sample = data_sample.cpu()

        if draw_gt and data_sample is not None:
            gt_img_data = image
            if 'gt_instances' in data_sample:
                gt_img_data = self._draw_instances(image,
                                                   data_sample.gt_instances,
                                                   classes, palette)
            if 'gt_sem_seg' in data_sample:
                gt_img_data = self._draw_sem_seg(gt_img_data,
                                                 data_sample.gt_sem_seg,
                                                 classes, palette)

            if 'gt_panoptic_seg' in data_sample:
                assert classes is not None, 'class information is ' \
                                            'not provided when ' \
                                            'visualizing panoptic ' \
                                            'segmentation results.'
                gt_img_data = self._draw_panoptic_seg(
                    gt_img_data, data_sample.gt_panoptic_seg, classes,
                    palette)

        if draw_pred and data_sample is not None:
            pred_img_data = image
            if 'pred_instances' in data_sample:
                pred_instances = data_sample.pred_instances
                pred_instances = pred_instances[
                    pred_instances.scores > pred_score_thr]
                pred_img_data = self._draw_instances(image, pred_instances,
                                                     classes, palette)
            if 'pred_sem_seg' in data_sample:
                pred_img_data = self._draw_sem_seg(pred_img_data,
                                                   data_sample.pred_sem_seg,
                                                   classes, palette)

            if 'pred_panoptic_seg' in data_sample:
                assert classes is not None, 'class information is ' \
                                            'not provided when ' \
                                            'visualizing panoptic ' \
                                            'segmentation results.'
                pred_img_data = self._draw_panoptic_seg(
                    pred_img_data, data_sample.pred_panoptic_seg.numpy(),
                    classes, palette)

        if gt_img_data is not None and pred_img_data is not None:
            drawn_img = np.concatenate((gt_img_data, pred_img_data), axis=1)
        elif gt_img_data is not None:
            drawn_img = gt_img_data
        elif pred_img_data is not None:
            drawn_img = pred_img_data
        else:
            # Display the original image directly if nothing is drawn.
            drawn_img = image

        # It is convenient for users to obtain the drawn image.
        # For example, the user wants to obtain the drawn image and
        # save it as a video during video inference.
        self.set_image(drawn_img)

        if show:
            self.show(drawn_img, win_name=name, wait_time=wait_time)

        if out_file is not None:
            mmcv.imwrite(drawn_img[..., ::-1], out_file)
        else:
            self.add_image(name, drawn_img, step)


def random_color(seed):
    """Generate a color deterministically from the input seed."""
    if sns is None:
        raise RuntimeError('seaborn is not installed, '
                           'please install it by: pip install seaborn')
    np.random.seed(seed)
    colors = sns.color_palette()
    color = colors[np.random.choice(range(len(colors)))]
    color = tuple([int(255 * c) for c in color])
    return color
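

# Note: ``np.random.seed(seed)`` fixes the RNG state, so the returned color
# is a pure function of the seed and a given track id keeps the same color
# across frames, e.g. (assuming seaborn is installed):
#
#   >>> random_color(42) == random_color(42)
#   True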


@VISUALIZERS.register_module()
class TrackLocalVisualizer(Visualizer):
    """Tracking Local Visualizer for the MOT, VIS tasks.

    Args:
        name (str): Name of the instance. Defaults to 'visualizer'.
        image (np.ndarray, optional): The origin image to draw. The format
            should be RGB. Defaults to None.
        vis_backends (list, optional): Visual backend config list.
            Defaults to None.
        save_dir (str, optional): Save file dir for all storage backends.
            If it is None, the backend storage will not save any data.
        line_width (int, float): The linewidth of lines.
            Defaults to 3.
        alpha (int, float): The transparency of bboxes or mask.
            Defaults to 0.8.
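
    Examples:
        >>> # An illustrative sketch of drawing one tracked box on a dummy
        >>> # frame; it assumes seaborn is installed, since the per-track
        >>> # colors come from ``random_color``.
        >>> import numpy as np
        >>> import torch
        >>> from mmengine.structures import InstanceData
        >>> from mmdet.structures import DetDataSample
        >>> from mmdet.visualization import TrackLocalVisualizer
        >>> track_local_visualizer = TrackLocalVisualizer()
        >>> image = np.random.randint(0, 256,
        ...                           size=(10, 12, 3)).astype('uint8')
        >>> gt_instances = InstanceData()
        >>> gt_instances.bboxes = torch.Tensor([[1, 2, 2, 5]])
        >>> gt_instances.labels = torch.randint(0, 2, (1,))
        >>> gt_instances.instances_id = torch.tensor([1])
        >>> gt_track_sample = DetDataSample()
        >>> gt_track_sample.gt_instances = gt_instances
        >>> track_local_visualizer.add_datasample(
        ...     'image', image, gt_track_sample, draw_pred=False)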
  467. """
  468. def __init__(self,
  469. name: str = 'visualizer',
  470. image: Optional[np.ndarray] = None,
  471. vis_backends: Optional[Dict] = None,
  472. save_dir: Optional[str] = None,
  473. line_width: Union[int, float] = 3,
  474. alpha: float = 0.8) -> None:
  475. super().__init__(name, image, vis_backends, save_dir)
  476. self.line_width = line_width
  477. self.alpha = alpha
  478. # Set default value. When calling
  479. # `TrackLocalVisualizer().dataset_meta=xxx`,
  480. # it will override the default value.
  481. self.dataset_meta = {}

    def _draw_instances(self, image: np.ndarray,
                        instances: InstanceData) -> np.ndarray:
        """Draw instances of GT or prediction.

        Args:
            image (np.ndarray): The image to draw.
            instances (:obj:`InstanceData`): Data structure for
                instance-level annotations or predictions.

        Returns:
            np.ndarray: The drawn image in RGB format.
        """
        self.set_image(image)
        classes = self.dataset_meta.get('classes', None)

        # get colors and texts
        # for the MOT and VIS tasks
        colors = [random_color(_id) for _id in instances.instances_id]
        categories = [
            classes[label] if classes is not None else f'cls{label}'
            for label in instances.labels
        ]
        if 'scores' in instances:
            texts = [
                f'{category_name}\n{instance_id} | {score:.2f}'
                for category_name, instance_id, score in zip(
                    categories, instances.instances_id, instances.scores)
            ]
        else:
            texts = [
                f'{category_name}\n{instance_id}' for category_name,
                instance_id in zip(categories, instances.instances_id)
            ]

        # draw bboxes and texts
        if 'bboxes' in instances:
            # draw bboxes
            bboxes = instances.bboxes.clone()
            self.draw_bboxes(
                bboxes,
                edge_colors=colors,
                alpha=self.alpha,
                line_widths=self.line_width)
            # draw texts
            if texts is not None:
                positions = bboxes[:, :2] + self.line_width
                areas = (bboxes[:, 3] - bboxes[:, 1]) * (
                    bboxes[:, 2] - bboxes[:, 0])
                scales = _get_adaptive_scales(areas.cpu().numpy())
                for i, pos in enumerate(positions):
                    self.draw_texts(
                        texts[i],
                        pos,
                        colors='black',
                        font_sizes=int(13 * scales[i]),
                        bboxes=[{
                            'facecolor': [c / 255 for c in colors[i]],
                            'alpha': 0.8,
                            'pad': 0.7,
                            'edgecolor': 'none'
                        }])

        # draw masks
        if 'masks' in instances:
            masks = instances.masks
            polygons = []
            for i, mask in enumerate(masks):
                contours, _ = bitmap_to_polygon(mask)
                polygons.extend(contours)
            self.draw_polygons(polygons, edge_colors='w', alpha=self.alpha)
            self.draw_binary_masks(masks, colors=colors, alphas=self.alpha)

        return self.get_image()

    @master_only
    def add_datasample(
            self,
            name: str,
            image: np.ndarray,
            data_sample: Optional[DetDataSample] = None,
            draw_gt: bool = True,
            draw_pred: bool = True,
            show: bool = False,
            wait_time: int = 0,
            # TODO: Supported in mmengine's Visualizer.
            out_file: Optional[str] = None,
            pred_score_thr: float = 0.3,
            step: int = 0) -> None:
        """Draw datasample and save to all backends.

        - If GT and prediction are plotted at the same time, they are
          displayed in a stitched image where the left image is the
          ground truth and the right image is the prediction.
        - If ``show`` is True, all storage backends are ignored, and
          the images will be displayed in a local window.
        - If ``out_file`` is specified, the drawn image will be
          saved to ``out_file``. It is usually used when the display
          is not available.

        Args:
            name (str): The image identifier.
            image (np.ndarray): The image to draw.
            data_sample (:obj:`DetDataSample`, optional): A data
                sample that contains annotations and predictions.
                Defaults to None.
            draw_gt (bool): Whether to draw GT TrackDataSample.
                Defaults to True.
            draw_pred (bool): Whether to draw Prediction TrackDataSample.
                Defaults to True.
            show (bool): Whether to display the drawn image.
                Defaults to False.
            wait_time (int): The interval of show (s). Defaults to 0.
            out_file (str): Path to output file. Defaults to None.
            pred_score_thr (float): The threshold to visualize the bboxes
                and masks. Defaults to 0.3.
            step (int): Global step value to record. Defaults to 0.
        """
        gt_img_data = None
        pred_img_data = None

        if data_sample is not None:
            data_sample = data_sample.cpu()

        if draw_gt and data_sample is not None:
            assert 'gt_instances' in data_sample
            gt_img_data = self._draw_instances(image,
                                               data_sample.gt_instances)

        if draw_pred and data_sample is not None:
            assert 'pred_track_instances' in data_sample
            pred_instances = data_sample.pred_track_instances
            if 'scores' in pred_instances:
                pred_instances = pred_instances[
                    pred_instances.scores > pred_score_thr].cpu()
            pred_img_data = self._draw_instances(image, pred_instances)

        if gt_img_data is not None and pred_img_data is not None:
            drawn_img = np.concatenate((gt_img_data, pred_img_data), axis=1)
        elif gt_img_data is not None:
            drawn_img = gt_img_data
        else:
            drawn_img = pred_img_data

        if show:
            self.show(drawn_img, win_name=name, wait_time=wait_time)

        if out_file is not None:
            mmcv.imwrite(drawn_img[..., ::-1], out_file)
        else:
            self.add_image(name, drawn_img, step)