inference.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import copy
  3. import warnings
  4. from pathlib import Path
  5. from typing import Optional, Sequence, Union
  6. import numpy as np
  7. import torch
  8. import torch.nn as nn
  9. from mmcv.ops import RoIPool
  10. from mmcv.transforms import Compose
  11. from mmengine.config import Config
  12. from mmengine.dataset import default_collate
  13. from mmengine.model.utils import revert_sync_batchnorm
  14. from mmengine.registry import init_default_scope
  15. from mmengine.runner import load_checkpoint
  16. from mmdet.registry import DATASETS
  17. from mmdet.utils import ConfigType
  18. from ..evaluation import get_classes
  19. from ..registry import MODELS
  20. from ..structures import DetDataSample, SampleList
  21. from ..utils import get_test_pipeline_cfg
  22. def init_detector(
  23. config: Union[str, Path, Config],
  24. checkpoint: Optional[str] = None,
  25. palette: str = 'none',
  26. device: str = 'cuda:0',
  27. cfg_options: Optional[dict] = None,
  28. ) -> nn.Module:
  29. """Initialize a detector from config file.
  30. Args:
  31. config (str, :obj:`Path`, or :obj:`mmengine.Config`): Config file path,
  32. :obj:`Path`, or the config object.
  33. checkpoint (str, optional): Checkpoint path. If left as None, the model
  34. will not load any weights.
  35. palette (str): Color palette used for visualization. If palette
  36. is stored in checkpoint, use checkpoint's palette first, otherwise
  37. use externally passed palette. Currently, supports 'coco', 'voc',
  38. 'citys' and 'random'. Defaults to none.
  39. device (str): The device where the anchors will be put on.
  40. Defaults to cuda:0.
  41. cfg_options (dict, optional): Options to override some settings in
  42. the used config.
  43. Returns:
  44. nn.Module: The constructed detector.
  45. """
  46. if isinstance(config, (str, Path)):
  47. config = Config.fromfile(config)
  48. elif not isinstance(config, Config):
  49. raise TypeError('config must be a filename or Config object, '
  50. f'but got {type(config)}')
  51. if cfg_options is not None:
  52. config.merge_from_dict(cfg_options)
  53. elif 'init_cfg' in config.model.backbone:
  54. config.model.backbone.init_cfg = None
  55. scope = config.get('default_scope', 'mmdet')
  56. if scope is not None:
  57. init_default_scope(config.get('default_scope', 'mmdet'))
  58. model = MODELS.build(config.model)
  59. model = revert_sync_batchnorm(model)
  60. if checkpoint is None:
  61. warnings.simplefilter('once')
  62. warnings.warn('checkpoint is None, use COCO classes by default.')
  63. model.dataset_meta = {'classes': get_classes('coco')}
  64. else:
  65. checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
  66. # Weights converted from elsewhere may not have meta fields.
  67. checkpoint_meta = checkpoint.get('meta', {})
  68. # save the dataset_meta in the model for convenience
  69. if 'dataset_meta' in checkpoint_meta:
  70. # mmdet 3.x, all keys should be lowercase
  71. model.dataset_meta = {
  72. k.lower(): v
  73. for k, v in checkpoint_meta['dataset_meta'].items()
  74. }
  75. elif 'CLASSES' in checkpoint_meta:
  76. # < mmdet 3.x
  77. classes = checkpoint_meta['CLASSES']
  78. model.dataset_meta = {'classes': classes}
  79. else:
  80. warnings.simplefilter('once')
  81. warnings.warn(
  82. 'dataset_meta or class names are not saved in the '
  83. 'checkpoint\'s meta data, use COCO classes by default.')
  84. model.dataset_meta = {'classes': get_classes('coco')}
  85. # Priority: args.palette -> config -> checkpoint
  86. if palette != 'none':
  87. model.dataset_meta['palette'] = palette
  88. else:
  89. test_dataset_cfg = copy.deepcopy(config.test_dataloader.dataset)
  90. # lazy init. We only need the metainfo.
  91. test_dataset_cfg['lazy_init'] = True
  92. metainfo = DATASETS.build(test_dataset_cfg).metainfo
  93. cfg_palette = metainfo.get('palette', None)
  94. if cfg_palette is not None:
  95. model.dataset_meta['palette'] = cfg_palette
  96. else:
  97. if 'palette' not in model.dataset_meta:
  98. warnings.warn(
  99. 'palette does not exist, random is used by default. '
  100. 'You can also set the palette to customize.')
  101. model.dataset_meta['palette'] = 'random'
  102. model.cfg = config # save the config in the model for convenience
  103. model.to(device)
  104. model.eval()
  105. return model
  106. ImagesType = Union[str, np.ndarray, Sequence[str], Sequence[np.ndarray]]
  107. def inference_detector(
  108. model: nn.Module,
  109. imgs: ImagesType,
  110. test_pipeline: Optional[Compose] = None,
  111. text_prompt: Optional[str] = None,
  112. custom_entities: bool = False,
  113. ) -> Union[DetDataSample, SampleList]:
  114. """Inference image(s) with the detector.
  115. Args:
  116. model (nn.Module): The loaded detector.
  117. imgs (str, ndarray, Sequence[str/ndarray]):
  118. Either image files or loaded images.
  119. test_pipeline (:obj:`Compose`): Test pipeline.
  120. Returns:
  121. :obj:`DetDataSample` or list[:obj:`DetDataSample`]:
  122. If imgs is a list or tuple, the same length list type results
  123. will be returned, otherwise return the detection results directly.
  124. """
  125. if isinstance(imgs, (list, tuple)):
  126. is_batch = True
  127. else:
  128. imgs = [imgs]
  129. is_batch = False
  130. cfg = model.cfg
  131. if test_pipeline is None:
  132. cfg = cfg.copy()
  133. test_pipeline = get_test_pipeline_cfg(cfg)
  134. if isinstance(imgs[0], np.ndarray):
  135. # Calling this method across libraries will result
  136. # in module unregistered error if not prefixed with mmdet.
  137. test_pipeline[0].type = 'mmdet.LoadImageFromNDArray'
  138. test_pipeline = Compose(test_pipeline)
  139. if model.data_preprocessor.device.type == 'cpu':
  140. for m in model.modules():
  141. assert not isinstance(
  142. m, RoIPool
  143. ), 'CPU inference with RoIPool is not supported currently.'
  144. result_list = []
  145. for i, img in enumerate(imgs):
  146. # prepare data
  147. if isinstance(img, np.ndarray):
  148. # TODO: remove img_id.
  149. data_ = dict(img=img, img_id=0)
  150. else:
  151. # TODO: remove img_id.
  152. data_ = dict(img_path=img, img_id=0)
  153. if text_prompt:
  154. data_['text'] = text_prompt
  155. data_['custom_entities'] = custom_entities
  156. # build the data pipeline
  157. data_ = test_pipeline(data_)
  158. data_['inputs'] = [data_['inputs']]
  159. data_['data_samples'] = [data_['data_samples']]
  160. # forward the model
  161. with torch.no_grad():
  162. results = model.test_step(data_)[0]
  163. result_list.append(results)
  164. if not is_batch:
  165. return result_list[0]
  166. else:
  167. return result_list
  168. # TODO: Awaiting refactoring
  169. async def async_inference_detector(model, imgs):
  170. """Async inference image(s) with the detector.
  171. Args:
  172. model (nn.Module): The loaded detector.
  173. img (str | ndarray): Either image files or loaded images.
  174. Returns:
  175. Awaitable detection results.
  176. """
  177. if not isinstance(imgs, (list, tuple)):
  178. imgs = [imgs]
  179. cfg = model.cfg
  180. if isinstance(imgs[0], np.ndarray):
  181. cfg = cfg.copy()
  182. # set loading pipeline type
  183. cfg.data.test.pipeline[0].type = 'LoadImageFromNDArray'
  184. # cfg.data.test.pipeline = replace_ImageToTensor(cfg.data.test.pipeline)
  185. test_pipeline = Compose(cfg.data.test.pipeline)
  186. datas = []
  187. for img in imgs:
  188. # prepare data
  189. if isinstance(img, np.ndarray):
  190. # directly add img
  191. data = dict(img=img)
  192. else:
  193. # add information into dict
  194. data = dict(img_info=dict(filename=img), img_prefix=None)
  195. # build the data pipeline
  196. data = test_pipeline(data)
  197. datas.append(data)
  198. for m in model.modules():
  199. assert not isinstance(
  200. m,
  201. RoIPool), 'CPU inference with RoIPool is not supported currently.'
  202. # We don't restore `torch.is_grad_enabled()` value during concurrent
  203. # inference since execution can overlap
  204. torch.set_grad_enabled(False)
  205. results = await model.aforward_test(data, rescale=True)
  206. return results
  207. def build_test_pipeline(cfg: ConfigType) -> ConfigType:
  208. """Build test_pipeline for mot/vis demo. In mot/vis infer, original
  209. test_pipeline should remove the "LoadImageFromFile" and
  210. "LoadTrackAnnotations".
  211. Args:
  212. cfg (ConfigDict): The loaded config.
  213. Returns:
  214. ConfigType: new test_pipeline
  215. """
  216. # remove the "LoadImageFromFile" and "LoadTrackAnnotations" in pipeline
  217. transform_broadcaster = cfg.test_dataloader.dataset.pipeline[0].copy()
  218. for transform in transform_broadcaster['transforms']:
  219. if transform['type'] == 'Resize':
  220. transform_broadcaster['transforms'] = transform
  221. pack_track_inputs = cfg.test_dataloader.dataset.pipeline[-1].copy()
  222. test_pipeline = Compose([transform_broadcaster, pack_track_inputs])
  223. return test_pipeline
  224. def inference_mot(model: nn.Module, img: np.ndarray, frame_id: int,
  225. video_len: int) -> SampleList:
  226. """Inference image(s) with the mot model.
  227. Args:
  228. model (nn.Module): The loaded mot model.
  229. img (np.ndarray): Loaded image.
  230. frame_id (int): frame id.
  231. video_len (int): demo video length
  232. Returns:
  233. SampleList: The tracking data samples.
  234. """
  235. cfg = model.cfg
  236. data = dict(
  237. img=[img.astype(np.float32)],
  238. frame_id=[frame_id],
  239. ori_shape=[img.shape[:2]],
  240. img_id=[frame_id + 1],
  241. ori_video_length=[video_len])
  242. test_pipeline = build_test_pipeline(cfg)
  243. data = test_pipeline(data)
  244. if not next(model.parameters()).is_cuda:
  245. for m in model.modules():
  246. assert not isinstance(
  247. m, RoIPool
  248. ), 'CPU inference with RoIPool is not supported currently.'
  249. # forward the model
  250. with torch.no_grad():
  251. data = default_collate([data])
  252. result = model.test_step(data)[0]
  253. return result
  254. def init_track_model(config: Union[str, Config],
  255. checkpoint: Optional[str] = None,
  256. detector: Optional[str] = None,
  257. reid: Optional[str] = None,
  258. device: str = 'cuda:0',
  259. cfg_options: Optional[dict] = None) -> nn.Module:
  260. """Initialize a model from config file.
  261. Args:
  262. config (str or :obj:`mmengine.Config`): Config file path or the config
  263. object.
  264. checkpoint (Optional[str], optional): Checkpoint path. Defaults to
  265. None.
  266. detector (Optional[str], optional): Detector Checkpoint path, use in
  267. some tracking algorithms like sort. Defaults to None.
  268. reid (Optional[str], optional): Reid checkpoint path. use in
  269. some tracking algorithms like sort. Defaults to None.
  270. device (str, optional): The device that the model inferences on.
  271. Defaults to `cuda:0`.
  272. cfg_options (Optional[dict], optional): Options to override some
  273. settings in the used config. Defaults to None.
  274. Returns:
  275. nn.Module: The constructed model.
  276. """
  277. if isinstance(config, str):
  278. config = Config.fromfile(config)
  279. elif not isinstance(config, Config):
  280. raise TypeError('config must be a filename or Config object, '
  281. f'but got {type(config)}')
  282. if cfg_options is not None:
  283. config.merge_from_dict(cfg_options)
  284. model = MODELS.build(config.model)
  285. if checkpoint is not None:
  286. checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
  287. # Weights converted from elsewhere may not have meta fields.
  288. checkpoint_meta = checkpoint.get('meta', {})
  289. # save the dataset_meta in the model for convenience
  290. if 'dataset_meta' in checkpoint_meta:
  291. if 'CLASSES' in checkpoint_meta['dataset_meta']:
  292. value = checkpoint_meta['dataset_meta'].pop('CLASSES')
  293. checkpoint_meta['dataset_meta']['classes'] = value
  294. model.dataset_meta = checkpoint_meta['dataset_meta']
  295. if detector is not None:
  296. assert not (checkpoint and detector), \
  297. 'Error: checkpoint and detector checkpoint cannot both exist'
  298. load_checkpoint(model.detector, detector, map_location='cpu')
  299. if reid is not None:
  300. assert not (checkpoint and reid), \
  301. 'Error: checkpoint and reid checkpoint cannot both exist'
  302. load_checkpoint(model.reid, reid, map_location='cpu')
  303. # Some methods don't load checkpoints or checkpoints don't contain
  304. # 'dataset_meta'
  305. # VIS need dataset_meta, MOT don't need dataset_meta
  306. if not hasattr(model, 'dataset_meta'):
  307. warnings.warn('dataset_meta or class names are missed, '
  308. 'use None by default.')
  309. model.dataset_meta = {'classes': None}
  310. model.cfg = config # save the config in the model for convenience
  311. model.to(device)
  312. model.eval()
  313. return model