_utils.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import copy
  3. from os.path import dirname, exists, join
  4. import numpy as np
  5. import torch
  6. from mmengine.config import Config
  7. from mmengine.dataset import pseudo_collate
  8. from mmengine.structures import InstanceData, PixelData
  9. from mmdet.utils.util_random import ensure_rng
  10. from ..registry import TASK_UTILS
  11. from ..structures import DetDataSample, TrackDataSample
  12. from ..structures.bbox import HorizontalBoxes
  13. def _get_config_directory():
  14. """Find the predefined detector config directory."""
  15. try:
  16. # Assume we are running in the source mmdetection repo
  17. repo_dpath = dirname(dirname(dirname(__file__)))
  18. except NameError:
  19. # For IPython development when this __file__ is not defined
  20. import mmdet
  21. repo_dpath = dirname(dirname(mmdet.__file__))
  22. config_dpath = join(repo_dpath, 'configs')
  23. if not exists(config_dpath):
  24. raise Exception('Cannot find config path')
  25. return config_dpath
  26. def _get_config_module(fname):
  27. """Load a configuration as a python module."""
  28. config_dpath = _get_config_directory()
  29. config_fpath = join(config_dpath, fname)
  30. config_mod = Config.fromfile(config_fpath)
  31. return config_mod
  32. def get_detector_cfg(fname):
  33. """Grab configs necessary to create a detector.
  34. These are deep copied to allow for safe modification of parameters without
  35. influencing other tests.
  36. """
  37. config = _get_config_module(fname)
  38. model = copy.deepcopy(config.model)
  39. return model
  40. def get_roi_head_cfg(fname):
  41. """Grab configs necessary to create a roi_head.
  42. These are deep copied to allow for safe modification of parameters without
  43. influencing other tests.
  44. """
  45. config = _get_config_module(fname)
  46. model = copy.deepcopy(config.model)
  47. roi_head = model.roi_head
  48. train_cfg = None if model.train_cfg is None else model.train_cfg.rcnn
  49. test_cfg = None if model.test_cfg is None else model.test_cfg.rcnn
  50. roi_head.update(dict(train_cfg=train_cfg, test_cfg=test_cfg))
  51. return roi_head
  52. def _rand_bboxes(rng, num_boxes, w, h):
  53. cx, cy, bw, bh = rng.rand(num_boxes, 4).T
  54. tl_x = ((cx * w) - (w * bw / 2)).clip(0, w)
  55. tl_y = ((cy * h) - (h * bh / 2)).clip(0, h)
  56. br_x = ((cx * w) + (w * bw / 2)).clip(0, w)
  57. br_y = ((cy * h) + (h * bh / 2)).clip(0, h)
  58. bboxes = np.vstack([tl_x, tl_y, br_x, br_y]).T
  59. return bboxes
  60. def _rand_masks(rng, num_boxes, bboxes, img_w, img_h):
  61. from mmdet.structures.mask import BitmapMasks
  62. masks = np.zeros((num_boxes, img_h, img_w))
  63. for i, bbox in enumerate(bboxes):
  64. bbox = bbox.astype(np.int32)
  65. mask = (rng.rand(1, bbox[3] - bbox[1], bbox[2] - bbox[0]) >
  66. 0.3).astype(np.int64)
  67. masks[i:i + 1, bbox[1]:bbox[3], bbox[0]:bbox[2]] = mask
  68. return BitmapMasks(masks, height=img_h, width=img_w)
  69. def demo_mm_inputs(batch_size=2,
  70. image_shapes=(3, 128, 128),
  71. num_items=None,
  72. num_classes=10,
  73. sem_seg_output_strides=1,
  74. with_mask=False,
  75. with_semantic=False,
  76. use_box_type=False,
  77. device='cpu',
  78. texts=None,
  79. custom_entities=False):
  80. """Create a superset of inputs needed to run test or train batches.
  81. Args:
  82. batch_size (int): batch size. Defaults to 2.
  83. image_shapes (List[tuple], Optional): image shape.
  84. Defaults to (3, 128, 128)
  85. num_items (None | List[int]): specifies the number
  86. of boxes in each batch item. Default to None.
  87. num_classes (int): number of different labels a
  88. box might have. Defaults to 10.
  89. with_mask (bool): Whether to return mask annotation.
  90. Defaults to False.
  91. with_semantic (bool): whether to return semantic.
  92. Defaults to False.
  93. device (str): Destination device type. Defaults to cpu.
  94. """
  95. rng = np.random.RandomState(0)
  96. if isinstance(image_shapes, list):
  97. assert len(image_shapes) == batch_size
  98. else:
  99. image_shapes = [image_shapes] * batch_size
  100. if isinstance(num_items, list):
  101. assert len(num_items) == batch_size
  102. if texts is not None:
  103. assert batch_size == len(texts)
  104. packed_inputs = []
  105. for idx in range(batch_size):
  106. image_shape = image_shapes[idx]
  107. c, h, w = image_shape
  108. image = rng.randint(0, 255, size=image_shape, dtype=np.uint8)
  109. mm_inputs = dict()
  110. mm_inputs['inputs'] = torch.from_numpy(image).to(device)
  111. img_meta = {
  112. 'img_id': idx,
  113. 'img_shape': image_shape[1:],
  114. 'ori_shape': image_shape[1:],
  115. 'filename': '<demo>.png',
  116. 'scale_factor': np.array([1.1, 1.2]),
  117. 'flip': False,
  118. 'flip_direction': None,
  119. 'border': [1, 1, 1, 1] # Only used by CenterNet
  120. }
  121. if texts:
  122. img_meta['text'] = texts[idx]
  123. img_meta['custom_entities'] = custom_entities
  124. data_sample = DetDataSample()
  125. data_sample.set_metainfo(img_meta)
  126. # gt_instances
  127. gt_instances = InstanceData()
  128. if num_items is None:
  129. num_boxes = rng.randint(1, 10)
  130. else:
  131. num_boxes = num_items[idx]
  132. bboxes = _rand_bboxes(rng, num_boxes, w, h)
  133. labels = rng.randint(1, num_classes, size=num_boxes)
  134. # TODO: remove this part when all model adapted with BaseBoxes
  135. if use_box_type:
  136. gt_instances.bboxes = HorizontalBoxes(bboxes, dtype=torch.float32)
  137. else:
  138. gt_instances.bboxes = torch.FloatTensor(bboxes)
  139. gt_instances.labels = torch.LongTensor(labels)
  140. if with_mask:
  141. masks = _rand_masks(rng, num_boxes, bboxes, w, h)
  142. gt_instances.masks = masks
  143. # TODO: waiting for ci to be fixed
  144. # masks = np.random.randint(0, 2, (len(bboxes), h, w), dtype=np.uint8)
  145. # gt_instances.mask = BitmapMasks(masks, h, w)
  146. data_sample.gt_instances = gt_instances
  147. # ignore_instances
  148. ignore_instances = InstanceData()
  149. bboxes = _rand_bboxes(rng, num_boxes, w, h)
  150. if use_box_type:
  151. ignore_instances.bboxes = HorizontalBoxes(
  152. bboxes, dtype=torch.float32)
  153. else:
  154. ignore_instances.bboxes = torch.FloatTensor(bboxes)
  155. data_sample.ignored_instances = ignore_instances
  156. # gt_sem_seg
  157. if with_semantic:
  158. # assume gt_semantic_seg using scale 1/8 of the img
  159. gt_semantic_seg = torch.from_numpy(
  160. np.random.randint(
  161. 0,
  162. num_classes, (1, h // sem_seg_output_strides,
  163. w // sem_seg_output_strides),
  164. dtype=np.uint8))
  165. gt_sem_seg_data = dict(sem_seg=gt_semantic_seg)
  166. data_sample.gt_sem_seg = PixelData(**gt_sem_seg_data)
  167. mm_inputs['data_samples'] = data_sample.to(device)
  168. # TODO: gt_ignore
  169. packed_inputs.append(mm_inputs)
  170. data = pseudo_collate(packed_inputs)
  171. return data
  172. def demo_mm_proposals(image_shapes, num_proposals, device='cpu'):
  173. """Create a list of fake porposals.
  174. Args:
  175. image_shapes (list[tuple[int]]): Batch image shapes.
  176. num_proposals (int): The number of fake proposals.
  177. """
  178. rng = np.random.RandomState(0)
  179. results = []
  180. for img_shape in image_shapes:
  181. result = InstanceData()
  182. w, h = img_shape[1:]
  183. proposals = _rand_bboxes(rng, num_proposals, w, h)
  184. result.bboxes = torch.from_numpy(proposals).float()
  185. result.scores = torch.from_numpy(rng.rand(num_proposals)).float()
  186. result.labels = torch.zeros(num_proposals).long()
  187. results.append(result.to(device))
  188. return results
  189. def demo_mm_sampling_results(proposals_list,
  190. batch_gt_instances,
  191. batch_gt_instances_ignore=None,
  192. assigner_cfg=None,
  193. sampler_cfg=None,
  194. feats=None):
  195. """Create sample results that can be passed to BBoxHead.get_targets."""
  196. assert len(proposals_list) == len(batch_gt_instances)
  197. if batch_gt_instances_ignore is None:
  198. batch_gt_instances_ignore = [None for _ in batch_gt_instances]
  199. else:
  200. assert len(batch_gt_instances_ignore) == len(batch_gt_instances)
  201. default_assigner_cfg = dict(
  202. type='MaxIoUAssigner',
  203. pos_iou_thr=0.5,
  204. neg_iou_thr=0.5,
  205. min_pos_iou=0.5,
  206. ignore_iof_thr=-1)
  207. assigner_cfg = assigner_cfg if assigner_cfg is not None \
  208. else default_assigner_cfg
  209. default_sampler_cfg = dict(
  210. type='RandomSampler',
  211. num=512,
  212. pos_fraction=0.25,
  213. neg_pos_ub=-1,
  214. add_gt_as_proposals=True)
  215. sampler_cfg = sampler_cfg if sampler_cfg is not None \
  216. else default_sampler_cfg
  217. bbox_assigner = TASK_UTILS.build(assigner_cfg)
  218. bbox_sampler = TASK_UTILS.build(sampler_cfg)
  219. sampling_results = []
  220. for i in range(len(batch_gt_instances)):
  221. if feats is not None:
  222. feats = [lvl_feat[i][None] for lvl_feat in feats]
  223. # rename proposals.bboxes to proposals.priors
  224. proposals = proposals_list[i]
  225. proposals.priors = proposals.pop('bboxes')
  226. assign_result = bbox_assigner.assign(proposals, batch_gt_instances[i],
  227. batch_gt_instances_ignore[i])
  228. sampling_result = bbox_sampler.sample(
  229. assign_result, proposals, batch_gt_instances[i], feats=feats)
  230. sampling_results.append(sampling_result)
  231. return sampling_results
  232. def demo_track_inputs(batch_size=1,
  233. num_frames=2,
  234. key_frames_inds=None,
  235. image_shapes=(3, 128, 128),
  236. num_items=None,
  237. num_classes=1,
  238. with_mask=False,
  239. with_semantic=False):
  240. """Create a superset of inputs needed to run test or train batches.
  241. Args:
  242. batch_size (int): batch size. Default to 1.
  243. num_frames (int): The number of frames.
  244. key_frames_inds (List): The indices of key frames.
  245. image_shapes (List[tuple], Optional): image shape.
  246. Default to (3, 128, 128)
  247. num_items (None | List[int]): specifies the number
  248. of boxes in each batch item. Default to None.
  249. num_classes (int): number of different labels a
  250. box might have. Default to 1.
  251. with_mask (bool): Whether to return mask annotation.
  252. Defaults to False.
  253. with_semantic (bool): whether to return semantic.
  254. Default to False.
  255. """
  256. rng = np.random.RandomState(0)
  257. # Make sure the length of image_shapes is equal to ``batch_size``
  258. if isinstance(image_shapes, list):
  259. assert len(image_shapes) == batch_size
  260. else:
  261. image_shapes = [image_shapes] * batch_size
  262. packed_inputs = []
  263. for idx in range(batch_size):
  264. mm_inputs = dict(inputs=dict())
  265. _, h, w = image_shapes[idx]
  266. imgs = rng.randint(
  267. 0, 255, size=(num_frames, *image_shapes[idx]), dtype=np.uint8)
  268. mm_inputs['inputs'] = torch.from_numpy(imgs)
  269. img_meta = {
  270. 'img_id': idx,
  271. 'img_shape': image_shapes[idx][-2:],
  272. 'ori_shape': image_shapes[idx][-2:],
  273. 'filename': '<demo>.png',
  274. 'scale_factor': np.array([1.1, 1.2]),
  275. 'flip': False,
  276. 'flip_direction': None,
  277. 'is_video_data': True,
  278. }
  279. video_data_samples = []
  280. for i in range(num_frames):
  281. data_sample = DetDataSample()
  282. img_meta['frame_id'] = i
  283. data_sample.set_metainfo(img_meta)
  284. # gt_instances
  285. gt_instances = InstanceData()
  286. if num_items is None:
  287. num_boxes = rng.randint(1, 10)
  288. else:
  289. num_boxes = num_items[idx]
  290. bboxes = _rand_bboxes(rng, num_boxes, w, h)
  291. labels = rng.randint(0, num_classes, size=num_boxes)
  292. instances_id = rng.randint(100, num_classes + 100, size=num_boxes)
  293. gt_instances.bboxes = torch.FloatTensor(bboxes)
  294. gt_instances.labels = torch.LongTensor(labels)
  295. gt_instances.instances_ids = torch.LongTensor(instances_id)
  296. if with_mask:
  297. masks = _rand_masks(rng, num_boxes, bboxes, w, h)
  298. gt_instances.masks = masks
  299. data_sample.gt_instances = gt_instances
  300. # ignore_instances
  301. ignore_instances = InstanceData()
  302. bboxes = _rand_bboxes(rng, num_boxes, w, h)
  303. ignore_instances.bboxes = bboxes
  304. data_sample.ignored_instances = ignore_instances
  305. video_data_samples.append(data_sample)
  306. track_data_sample = TrackDataSample()
  307. track_data_sample.video_data_samples = video_data_samples
  308. if key_frames_inds is not None:
  309. assert isinstance(
  310. key_frames_inds,
  311. list) and len(key_frames_inds) < num_frames and max(
  312. key_frames_inds) < num_frames
  313. ref_frames_inds = [
  314. i for i in range(num_frames) if i not in key_frames_inds
  315. ]
  316. track_data_sample.set_metainfo(
  317. dict(key_frames_inds=key_frames_inds))
  318. track_data_sample.set_metainfo(
  319. dict(ref_frames_inds=ref_frames_inds))
  320. mm_inputs['data_samples'] = track_data_sample
  321. # TODO: gt_ignore
  322. packed_inputs.append(mm_inputs)
  323. data = pseudo_collate(packed_inputs)
  324. return data
  325. def random_boxes(num=1, scale=1, rng=None):
  326. """Simple version of ``kwimage.Boxes.random``
  327. Returns:
  328. Tensor: shape (n, 4) in x1, y1, x2, y2 format.
  329. References:
  330. https://gitlab.kitware.com/computer-vision/kwimage/blob/master/kwimage/structs/boxes.py#L1390 # noqa: E501
  331. Example:
  332. >>> num = 3
  333. >>> scale = 512
  334. >>> rng = 0
  335. >>> boxes = random_boxes(num, scale, rng)
  336. >>> print(boxes)
  337. tensor([[280.9925, 278.9802, 308.6148, 366.1769],
  338. [216.9113, 330.6978, 224.0446, 456.5878],
  339. [405.3632, 196.3221, 493.3953, 270.7942]])
  340. """
  341. rng = ensure_rng(rng)
  342. tlbr = rng.rand(num, 4).astype(np.float32)
  343. tl_x = np.minimum(tlbr[:, 0], tlbr[:, 2])
  344. tl_y = np.minimum(tlbr[:, 1], tlbr[:, 3])
  345. br_x = np.maximum(tlbr[:, 0], tlbr[:, 2])
  346. br_y = np.maximum(tlbr[:, 1], tlbr[:, 3])
  347. tlbr[:, 0] = tl_x * scale
  348. tlbr[:, 1] = tl_y * scale
  349. tlbr[:, 2] = br_x * scale
  350. tlbr[:, 3] = br_y * scale
  351. boxes = torch.from_numpy(tlbr)
  352. return boxes
  353. # TODO: Support full ceph
  354. def replace_to_ceph(cfg):
  355. backend_args = dict(
  356. backend='petrel',
  357. path_mapping=dict({
  358. './data/': 's3://openmmlab/datasets/detection/',
  359. 'data/': 's3://openmmlab/datasets/detection/'
  360. }))
  361. # TODO: name is a reserved interface, which will be used later.
  362. def _process_pipeline(dataset, name):
  363. def replace_img(pipeline):
  364. if pipeline['type'] == 'LoadImageFromFile':
  365. pipeline['backend_args'] = backend_args
  366. def replace_ann(pipeline):
  367. if pipeline['type'] == 'LoadAnnotations' or pipeline[
  368. 'type'] == 'LoadPanopticAnnotations':
  369. pipeline['backend_args'] = backend_args
  370. if 'pipeline' in dataset:
  371. replace_img(dataset.pipeline[0])
  372. replace_ann(dataset.pipeline[1])
  373. if 'dataset' in dataset:
  374. # dataset wrapper
  375. replace_img(dataset.dataset.pipeline[0])
  376. replace_ann(dataset.dataset.pipeline[1])
  377. else:
  378. # dataset wrapper
  379. replace_img(dataset.dataset.pipeline[0])
  380. replace_ann(dataset.dataset.pipeline[1])
  381. def _process_evaluator(evaluator, name):
  382. if evaluator['type'] == 'CocoPanopticMetric':
  383. evaluator['backend_args'] = backend_args
  384. # half ceph
  385. _process_pipeline(cfg.train_dataloader.dataset, cfg.filename)
  386. _process_pipeline(cfg.val_dataloader.dataset, cfg.filename)
  387. _process_pipeline(cfg.test_dataloader.dataset, cfg.filename)
  388. _process_evaluator(cfg.val_evaluator, cfg.filename)
  389. _process_evaluator(cfg.test_evaluator, cfg.filename)