# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Tuple, Union

import mmcv
import numpy as np
import pycocotools.mask as maskUtils
import torch
from mmcv.transforms import BaseTransform
from mmcv.transforms import LoadAnnotations as MMCV_LoadAnnotations
from mmcv.transforms import LoadImageFromFile
from mmengine.fileio import get
from mmengine.structures import BaseDataElement

from mmdet.registry import TRANSFORMS
from mmdet.structures.bbox import get_box_type
from mmdet.structures.bbox.box_type import autocast_box_type
from mmdet.structures.mask import BitmapMasks, PolygonMasks


@TRANSFORMS.register_module()
class LoadImageFromNDArray(LoadImageFromFile):
    """Load an image from ``results['img']``.

    Similar to :obj:`LoadImageFromFile`, but the image has already been loaded
    as an :obj:`np.ndarray` in ``results['img']``. Can be used when loading an
    image from a webcam.

    Required Keys:

    - img

    Modified Keys:

    - img
    - img_path
    - img_shape
    - ori_shape

    Args:
        to_float32 (bool): Whether to convert the loaded image to a float32
            numpy array. If set to False, the loaded image is a uint8 array.
            Defaults to False.
    """

    def transform(self, results: dict) -> dict:
        """Transform function to add image meta information.

        Args:
            results (dict): Result dict with the webcam-read image in
                ``results['img']``.

        Returns:
            dict: The dict contains loaded image and meta information.
        """
        img = results['img']
        if self.to_float32:
            img = img.astype(np.float32)

        results['img_path'] = None
        results['img'] = img
        results['img_shape'] = img.shape[:2]
        results['ori_shape'] = img.shape[:2]
        return results
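

# A minimal usage sketch (not part of the original file; the array shape and
# ``to_float32`` value below are illustrative assumptions), showing how an
# in-memory frame, e.g. from a webcam, flows through ``LoadImageFromNDArray``:
#
#   >>> import numpy as np
#   >>> frame = np.zeros((480, 640, 3), dtype=np.uint8)
#   >>> loader = LoadImageFromNDArray(to_float32=True)
#   >>> out = loader.transform(dict(img=frame))
#   >>> out['img_shape'], out['img'].dtype
#   ((480, 640), dtype('float32'))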


@TRANSFORMS.register_module()
class LoadMultiChannelImageFromFiles(BaseTransform):
    """Load multi-channel images from a list of separate channel files.

    Required Keys:

    - img_path

    Modified Keys:

    - img
    - img_shape
    - ori_shape

    Args:
        to_float32 (bool): Whether to convert the loaded image to a float32
            numpy array. If set to False, the loaded image is a uint8 array.
            Defaults to False.
        color_type (str): The flag argument for :func:`mmcv.imfrombytes`.
            Defaults to 'unchanged'.
        imdecode_backend (str): The image decoding backend type. The backend
            argument for :func:`mmcv.imfrombytes`.
            See :func:`mmcv.imfrombytes` for details.
            Defaults to 'cv2'.
        file_client_args (dict): Arguments to instantiate the
            corresponding backend in mmdet <= 3.0.0rc6. Deprecated;
            defaults to None.
        backend_args (dict, optional): Arguments to instantiate the
            corresponding backend in mmdet >= 3.0.0rc7. Defaults to None.
    """

    def __init__(
        self,
        to_float32: bool = False,
        color_type: str = 'unchanged',
        imdecode_backend: str = 'cv2',
        file_client_args: dict = None,
        backend_args: dict = None,
    ) -> None:
        self.to_float32 = to_float32
        self.color_type = color_type
        self.imdecode_backend = imdecode_backend
        self.backend_args = backend_args
        if file_client_args is not None:
            raise RuntimeError(
                'The `file_client_args` is deprecated, '
                'please use `backend_args` instead, please refer to '
                'https://github.com/open-mmlab/mmdetection/blob/main/configs/_base_/datasets/coco_detection.py'  # noqa: E501
            )

    def transform(self, results: dict) -> dict:
        """Transform function to load multiple images and get image meta
        information.

        Args:
            results (dict): Result dict from :obj:`mmdet.CustomDataset`.

        Returns:
            dict: The dict contains loaded images and meta information.
        """
        assert isinstance(results['img_path'], list)
        img = []
        for name in results['img_path']:
            img_bytes = get(name, backend_args=self.backend_args)
            img.append(
                mmcv.imfrombytes(
                    img_bytes,
                    flag=self.color_type,
                    backend=self.imdecode_backend))
        img = np.stack(img, axis=-1)
        if self.to_float32:
            img = img.astype(np.float32)

        results['img'] = img
        results['img_shape'] = img.shape[:2]
        results['ori_shape'] = img.shape[:2]
        return results

    def __repr__(self):
        repr_str = (f'{self.__class__.__name__}('
                    f'to_float32={self.to_float32}, '
                    f"color_type='{self.color_type}', "
                    f"imdecode_backend='{self.imdecode_backend}', "
                    f'backend_args={self.backend_args})')
        return repr_str
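

# Usage sketch (illustrative; the file names below are placeholders for
# single-channel files such as separate spectral bands). Each path is decoded
# independently and the results are stacked along the last axis, so ``img``
# ends up with one channel per listed file:
#
#   >>> loader = LoadMultiChannelImageFromFiles(color_type='unchanged')
#   >>> results = loader.transform(
#   ...     dict(img_path=['band_1.tif', 'band_2.tif']))
#   >>> results['img'].shape  # (H, W, num_files) for single-channel inputs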


@TRANSFORMS.register_module()
class LoadAnnotations(MMCV_LoadAnnotations):
    """Load and process the ``instances`` and ``seg_map`` annotation provided
    by the dataset.

    The annotation format is as follows:

    .. code-block:: python

        {
            'instances':
            [
                {
                # List of 4 numbers representing the bounding box of the
                # instance, in (x1, y1, x2, y2) order.
                'bbox': [x1, y1, x2, y2],

                # Label of image classification.
                'bbox_label': 1,

                # Used in instance/panoptic segmentation. The segmentation mask
                # of the instance or the information of segments.
                # 1. If list[list[float]], it represents a list of polygons,
                # one for each connected component of the object. Each
                # list[float] is one simple polygon in the format of
                # [x1, y1, ..., xn, yn] (n >= 3). The Xs and Ys are absolute
                # coordinates in unit of pixels.
                # 2. If dict, it represents the per-pixel segmentation mask in
                # COCO's compressed RLE format. The dict should have keys
                # "size" and "counts". Can be loaded by pycocotools.
                'mask': list[list[float]] or dict,

                }
            ]
            # Filename of semantic or panoptic segmentation ground truth file.
            'seg_map_path': 'a/b/c'
        }

    After this module, the annotation has been changed to the format below:

    .. code-block:: python

        {
            # In (x1, y1, x2, y2) order, float type. N is the number of bboxes
            # in an image
            'gt_bboxes': BaseBoxes(N, 4)
             # In int type.
            'gt_bboxes_labels': np.ndarray(N, )
             # In built-in class
            'gt_masks': PolygonMasks (H, W) or BitmapMasks (H, W)
             # In uint8 type.
            'gt_seg_map': np.ndarray (H, W)
        }

    Required Keys:

    - height
    - width
    - instances

      - bbox (optional)
      - bbox_label
      - mask (optional)
      - ignore_flag

    - seg_map_path (optional)

    Added Keys:

    - gt_bboxes (BaseBoxes[torch.float32])
    - gt_bboxes_labels (np.int64)
    - gt_masks (BitmapMasks | PolygonMasks)
    - gt_seg_map (np.uint8)
    - gt_ignore_flags (bool)

    Args:
        with_bbox (bool): Whether to parse and load the bbox annotation.
            Defaults to True.
        with_label (bool): Whether to parse and load the label annotation.
            Defaults to True.
        with_mask (bool): Whether to parse and load the mask annotation.
            Defaults to False.
        with_seg (bool): Whether to parse and load the semantic segmentation
            annotation. Defaults to False.
        poly2mask (bool): Whether to convert polygon masks to bitmap masks.
            Defaults to True.
        box_type (str): The box type used to wrap the bboxes. If ``box_type``
            is None, gt_bboxes will keep being np.ndarray. Defaults to 'hbox'.
        reduce_zero_label (bool): Whether to reduce all label values by 1.
            Usually used for datasets where 0 is the background label.
            Defaults to False.
        ignore_index (int): The label index to be ignored.
            Valid only if reduce_zero_label is true. Defaults to 255.
        imdecode_backend (str): The image decoding backend type. The backend
            argument for :func:`mmcv.imfrombytes`.
            See :func:`mmcv.imfrombytes` for details.
            Defaults to 'cv2'.
        backend_args (dict, optional): Arguments to instantiate the
            corresponding backend. Defaults to None.
    """

    def __init__(
            self,
            with_mask: bool = False,
            poly2mask: bool = True,
            box_type: str = 'hbox',
            # use for semseg
            reduce_zero_label: bool = False,
            ignore_index: int = 255,
            **kwargs) -> None:
        super(LoadAnnotations, self).__init__(**kwargs)
        self.with_mask = with_mask
        self.poly2mask = poly2mask
        self.box_type = box_type
        self.reduce_zero_label = reduce_zero_label
        self.ignore_index = ignore_index

    def _load_bboxes(self, results: dict) -> None:
        """Private function to load bounding box annotations.

        Args:
            results (dict): Result dict from :obj:`mmengine.BaseDataset`.

        Returns:
            dict: The dict contains loaded bounding box annotations.
        """
        gt_bboxes = []
        gt_ignore_flags = []
        for instance in results.get('instances', []):
            gt_bboxes.append(instance['bbox'])
            gt_ignore_flags.append(instance['ignore_flag'])
        if self.box_type is None:
            results['gt_bboxes'] = np.array(
                gt_bboxes, dtype=np.float32).reshape((-1, 4))
        else:
            _, box_type_cls = get_box_type(self.box_type)
            results['gt_bboxes'] = box_type_cls(gt_bboxes, dtype=torch.float32)
        results['gt_ignore_flags'] = np.array(gt_ignore_flags, dtype=bool)

    def _load_labels(self, results: dict) -> None:
        """Private function to load label annotations.

        Args:
            results (dict): Result dict from :obj:`mmengine.BaseDataset`.

        Returns:
            dict: The dict contains loaded label annotations.
        """
        gt_bboxes_labels = []
        for instance in results.get('instances', []):
            gt_bboxes_labels.append(instance['bbox_label'])
        # TODO: Inconsistent with mmcv, consider how to deal with it later.
        results['gt_bboxes_labels'] = np.array(
            gt_bboxes_labels, dtype=np.int64)

    def _poly2mask(self, mask_ann: Union[list, dict], img_h: int,
                   img_w: int) -> np.ndarray:
        """Private function to convert masks represented with polygons to
        bitmaps.

        Args:
            mask_ann (list | dict): Polygon mask annotation input.
            img_h (int): The height of the output mask.
            img_w (int): The width of the output mask.

        Returns:
            np.ndarray: The decoded bitmap mask of shape (img_h, img_w).
        """
        if isinstance(mask_ann, list):
            # polygon -- a single object might consist of multiple parts
            # we merge all parts into one mask rle code
            rles = maskUtils.frPyObjects(mask_ann, img_h, img_w)
            rle = maskUtils.merge(rles)
        elif isinstance(mask_ann['counts'], list):
            # uncompressed RLE
            rle = maskUtils.frPyObjects(mask_ann, img_h, img_w)
        else:
            # rle
            rle = mask_ann
        mask = maskUtils.decode(rle)
        return mask
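
    # The conversion above follows the standard pycocotools recipe. A
    # stand-alone sketch with a made-up triangle polygon (coordinates and
    # mask size are illustrative only):
    #
    #   >>> import pycocotools.mask as maskUtils
    #   >>> poly = [[10., 10., 40., 10., 25., 35.]]  # one [x1, y1, ...] ring
    #   >>> rles = maskUtils.frPyObjects(poly, 50, 50)  # img_h=50, img_w=50
    #   >>> mask = maskUtils.decode(maskUtils.merge(rles))
    #   >>> mask.shape, mask.dtype
    #   ((50, 50), dtype('uint8'))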

    def _process_masks(self, results: dict) -> list:
        """Process gt_masks and filter invalid polygons.

        Args:
            results (dict): Result dict from :obj:`mmengine.BaseDataset`.

        Returns:
            list: Processed gt_masks.
        """
        gt_masks = []
        gt_ignore_flags = []
        for instance in results.get('instances', []):
            gt_mask = instance['mask']
            # If the annotation of segmentation mask is invalid,
            # ignore the whole instance.
            if isinstance(gt_mask, list):
                gt_mask = [
                    np.array(polygon) for polygon in gt_mask
                    if len(polygon) % 2 == 0 and len(polygon) >= 6
                ]
                if len(gt_mask) == 0:
                    # ignore this instance and set gt_mask to a fake mask
                    instance['ignore_flag'] = 1
                    gt_mask = [np.zeros(6)]
            elif not self.poly2mask:
                # `PolygonMasks` requires a polygon of format List[np.array],
                # other formats are invalid.
                instance['ignore_flag'] = 1
                gt_mask = [np.zeros(6)]
            elif isinstance(gt_mask, dict) and \
                    not (gt_mask.get('counts') is not None and
                         gt_mask.get('size') is not None and
                         isinstance(gt_mask['counts'], (list, str))):
                # if gt_mask is a dict, it should include `counts` and `size`,
                # so that `BitmapMasks` can decode the (uncompressed) RLE
                instance['ignore_flag'] = 1
                gt_mask = [np.zeros(6)]
            gt_masks.append(gt_mask)
            # re-process gt_ignore_flags
            gt_ignore_flags.append(instance['ignore_flag'])
        results['gt_ignore_flags'] = np.array(gt_ignore_flags, dtype=bool)
        return gt_masks

    def _load_masks(self, results: dict) -> None:
        """Private function to load mask annotations.

        Args:
            results (dict): Result dict from :obj:`mmengine.BaseDataset`.
        """
        h, w = results['ori_shape']
        gt_masks = self._process_masks(results)
        if self.poly2mask:
            gt_masks = BitmapMasks(
                [self._poly2mask(mask, h, w) for mask in gt_masks], h, w)
        else:
            # fake polygon masks will be ignored in `PackDetInputs`
            gt_masks = PolygonMasks([mask for mask in gt_masks], h, w)
        results['gt_masks'] = gt_masks

    def _load_seg_map(self, results: dict) -> None:
        """Private function to load semantic segmentation annotations.

        Args:
            results (dict): Result dict from :obj:`mmcv.BaseDataset`.

        Returns:
            dict: The dict contains loaded semantic segmentation annotations.
        """
        if results.get('seg_map_path', None) is None:
            return

        img_bytes = get(
            results['seg_map_path'], backend_args=self.backend_args)
        gt_semantic_seg = mmcv.imfrombytes(
            img_bytes, flag='unchanged',
            backend=self.imdecode_backend).squeeze()

        if self.reduce_zero_label:
            # avoid using underflow conversion
            gt_semantic_seg[gt_semantic_seg == 0] = self.ignore_index
            gt_semantic_seg = gt_semantic_seg - 1
            gt_semantic_seg[gt_semantic_seg == self.ignore_index -
                            1] = self.ignore_index

        # modify if custom classes
        if results.get('label_map', None) is not None:
            # Add deep copy to solve the bug of repeatedly replacing
            # `gt_semantic_seg`, which is reported in
            # https://github.com/open-mmlab/mmsegmentation/pull/1445/
            gt_semantic_seg_copy = gt_semantic_seg.copy()
            for old_id, new_id in results['label_map'].items():
                gt_semantic_seg[gt_semantic_seg_copy == old_id] = new_id
        results['gt_seg_map'] = gt_semantic_seg
        results['ignore_index'] = self.ignore_index

    def transform(self, results: dict) -> dict:
        """Function to load multiple types of annotations.

        Args:
            results (dict): Result dict from :obj:`mmengine.BaseDataset`.

        Returns:
            dict: The dict contains loaded bounding box, label and
            semantic segmentation annotations.
        """
        if self.with_bbox:
            self._load_bboxes(results)
        if self.with_label:
            self._load_labels(results)
        if self.with_mask:
            self._load_masks(results)
        if self.with_seg:
            self._load_seg_map(results)
        return results

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += f'(with_bbox={self.with_bbox}, '
        repr_str += f'with_label={self.with_label}, '
        repr_str += f'with_mask={self.with_mask}, '
        repr_str += f'with_seg={self.with_seg}, '
        repr_str += f'poly2mask={self.poly2mask}, '
        repr_str += f"imdecode_backend='{self.imdecode_backend}', "
        repr_str += f'backend_args={self.backend_args})'
        return repr_str
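

# Usage sketch (illustrative; the ``instances`` entry is hand-written and the
# values are made up). ``ori_shape`` is normally filled in by the image
# loading transform that runs earlier in the pipeline:
#
#   >>> results = dict(
#   ...     ori_shape=(100, 100),
#   ...     instances=[dict(bbox=[10, 10, 50, 60], bbox_label=0,
#   ...                     ignore_flag=0)])
#   >>> results = LoadAnnotations(with_bbox=True, with_label=True)(results)
#   >>> # gt_bboxes is wrapped in the 'hbox' box type (HorizontalBoxes) and
#   >>> # gt_bboxes_labels / gt_ignore_flags become numpy arrays.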


@TRANSFORMS.register_module()
class LoadPanopticAnnotations(LoadAnnotations):
    """Load multiple types of panoptic annotations.

    The annotation format is as follows:

    .. code-block:: python

        {
            'instances':
            [
                {
                # List of 4 numbers representing the bounding box of the
                # instance, in (x1, y1, x2, y2) order.
                'bbox': [x1, y1, x2, y2],

                # Label of image classification.
                'bbox_label': 1,
                },
                ...
            ]
            'segments_info':
            [
                {
                # id = cls_id + instance_id * INSTANCE_OFFSET
                'id': int,

                # Contiguous category id defined in dataset.
                'category': int

                # Thing flag.
                'is_thing': bool
                },
                ...
            ]

            # Filename of semantic or panoptic segmentation ground truth file.
            'seg_map_path': 'a/b/c'
        }

    After this module, the annotation has been changed to the format below:

    .. code-block:: python

        {
            # In (x1, y1, x2, y2) order, float type. N is the number of bboxes
            # in an image
            'gt_bboxes': BaseBoxes(N, 4)
             # In int type.
            'gt_bboxes_labels': np.ndarray(N, )
             # In built-in class
            'gt_masks': PolygonMasks (H, W) or BitmapMasks (H, W)
             # In uint8 type.
            'gt_seg_map': np.ndarray (H, W)
        }

    Required Keys:

    - height
    - width
    - instances

      - bbox
      - bbox_label
      - ignore_flag

    - segments_info

      - id
      - category
      - is_thing

    - seg_map_path

    Added Keys:

    - gt_bboxes (BaseBoxes[torch.float32])
    - gt_bboxes_labels (np.int64)
    - gt_masks (BitmapMasks | PolygonMasks)
    - gt_seg_map (np.uint8)
    - gt_ignore_flags (bool)

    Args:
        with_bbox (bool): Whether to parse and load the bbox annotation.
            Defaults to True.
        with_label (bool): Whether to parse and load the label annotation.
            Defaults to True.
        with_mask (bool): Whether to parse and load the mask annotation.
            Defaults to True.
        with_seg (bool): Whether to parse and load the semantic segmentation
            annotation. Defaults to True.
        box_type (str): The box type used to wrap the bboxes. Defaults to
            'hbox'.
        imdecode_backend (str): The image decoding backend type. The backend
            argument for :func:`mmcv.imfrombytes`.
            See :func:`mmcv.imfrombytes` for details.
            Defaults to 'cv2'.
        backend_args (dict, optional): Arguments to instantiate the
            corresponding backend in mmdet >= 3.0.0rc7. Defaults to None.
    """

    def __init__(self,
                 with_bbox: bool = True,
                 with_label: bool = True,
                 with_mask: bool = True,
                 with_seg: bool = True,
                 box_type: str = 'hbox',
                 imdecode_backend: str = 'cv2',
                 backend_args: dict = None) -> None:
        try:
            from panopticapi import utils
        except ImportError:
            raise ImportError(
                'panopticapi is not installed, please install it by: '
                'pip install git+https://github.com/cocodataset/'
                'panopticapi.git.')
        self.rgb2id = utils.rgb2id

        super(LoadPanopticAnnotations, self).__init__(
            with_bbox=with_bbox,
            with_label=with_label,
            with_mask=with_mask,
            with_seg=with_seg,
            with_keypoints=False,
            box_type=box_type,
            imdecode_backend=imdecode_backend,
            backend_args=backend_args)

    def _load_masks_and_semantic_segs(self, results: dict) -> None:
        """Private function to load mask and semantic segmentation annotations.

        In gt_semantic_seg, the foreground label is from ``0`` to
        ``num_things - 1``, the background label is from ``num_things`` to
        ``num_things + num_stuff - 1``, 255 means the ignored label (``VOID``).

        Args:
            results (dict): Result dict from :obj:`mmdet.CustomDataset`.
        """
        # seg_map_path is None when doing inference on a dataset without gts.
        if results.get('seg_map_path', None) is None:
            return

        img_bytes = get(
            results['seg_map_path'], backend_args=self.backend_args)
        pan_png = mmcv.imfrombytes(
            img_bytes, flag='color', channel_order='rgb').squeeze()
        pan_png = self.rgb2id(pan_png)

        gt_masks = []
        gt_seg = np.zeros_like(pan_png) + 255  # 255 as ignore

        for segment_info in results['segments_info']:
            mask = (pan_png == segment_info['id'])
            gt_seg = np.where(mask, segment_info['category'], gt_seg)

            # The legal thing masks
            if segment_info.get('is_thing'):
                gt_masks.append(mask.astype(np.uint8))

        if self.with_mask:
            h, w = results['ori_shape']
            gt_masks = BitmapMasks(gt_masks, h, w)
            results['gt_masks'] = gt_masks

        if self.with_seg:
            results['gt_seg_map'] = gt_seg

    def transform(self, results: dict) -> dict:
        """Function to load multiple types of panoptic annotations.

        Args:
            results (dict): Result dict from :obj:`mmdet.CustomDataset`.

        Returns:
            dict: The dict contains loaded bounding box, label, mask and
            semantic segmentation annotations.
        """
        if self.with_bbox:
            self._load_bboxes(results)
        if self.with_label:
            self._load_labels(results)
        if self.with_mask or self.with_seg:
            # The tasks completed by '_load_masks' and '_load_semantic_segs'
            # in LoadAnnotations are merged into one function.
            self._load_masks_and_semantic_segs(results)

        return results
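

# Usage sketch (illustrative; the ids, category and path below are
# placeholders, and ``panopticapi`` must be installed). The transform expects
# the extra ``segments_info`` list and a panoptic PNG at ``seg_map_path``:
#
#   >>> transform = LoadPanopticAnnotations(with_bbox=True, with_seg=True)
#   >>> results = dict(
#   ...     ori_shape=(512, 512),
#   ...     instances=[dict(bbox=[0, 0, 10, 10], bbox_label=0, ignore_flag=0)],
#   ...     segments_info=[dict(id=1, category=0, is_thing=True)],
#   ...     seg_map_path='path/to/panoptic_000000.png')
#   >>> results = transform(results)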


@TRANSFORMS.register_module()
class LoadProposals(BaseTransform):
    """Load proposal pipeline.

    Required Keys:

    - proposals

    Modified Keys:

    - proposals

    Args:
        num_max_proposals (int, optional): Maximum number of proposals to
            load. If not specified, all proposals will be loaded.
    """

    def __init__(self, num_max_proposals: Optional[int] = None) -> None:
        self.num_max_proposals = num_max_proposals

    def transform(self, results: dict) -> dict:
        """Transform function to load proposals from file.

        Args:
            results (dict): Result dict from :obj:`mmdet.CustomDataset`.

        Returns:
            dict: The dict contains loaded proposal annotations.
        """
        proposals = results['proposals']
        # the type of proposals should be `dict` or `InstanceData`
        assert isinstance(proposals, dict) \
            or isinstance(proposals, BaseDataElement)
        bboxes = proposals['bboxes'].astype(np.float32)
        assert bboxes.shape[1] == 4, \
            f'Proposals should have shapes (n, 4), but found {bboxes.shape}'

        if 'scores' in proposals:
            scores = proposals['scores'].astype(np.float32)
            assert bboxes.shape[0] == scores.shape[0]
        else:
            scores = np.zeros(bboxes.shape[0], dtype=np.float32)

        if self.num_max_proposals is not None:
            # proposals should be sorted by score when they are dumped
            bboxes = bboxes[:self.num_max_proposals]
            scores = scores[:self.num_max_proposals]

        if len(bboxes) == 0:
            bboxes = np.zeros((0, 4), dtype=np.float32)
            scores = np.zeros(0, dtype=np.float32)

        results['proposals'] = bboxes
        results['proposals_scores'] = scores
        return results

    def __repr__(self):
        return self.__class__.__name__ + \
            f'(num_max_proposals={self.num_max_proposals})'
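

# Usage sketch (made-up proposal boxes and scores, assumed to be pre-sorted
# by score as the comment in ``transform`` requires):
#
#   >>> proposals = dict(
#   ...     bboxes=np.array([[0, 0, 10, 10], [5, 5, 20, 20]], dtype=np.float32),
#   ...     scores=np.array([0.9, 0.4], dtype=np.float32))
#   >>> out = LoadProposals(num_max_proposals=1)(dict(proposals=proposals))
#   >>> out['proposals'].shape, out['proposals_scores'].shape
#   ((1, 4), (1,))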


@TRANSFORMS.register_module()
class FilterAnnotations(BaseTransform):
    """Filter invalid annotations.

    Required Keys:

    - gt_bboxes (BaseBoxes[torch.float32]) (optional)
    - gt_bboxes_labels (np.int64) (optional)
    - gt_masks (BitmapMasks | PolygonMasks) (optional)
    - gt_ignore_flags (bool) (optional)

    Modified Keys:

    - gt_bboxes (optional)
    - gt_bboxes_labels (optional)
    - gt_masks (optional)
    - gt_ignore_flags (optional)

    Args:
        min_gt_bbox_wh (tuple[float]): Minimum width and height of ground
            truth boxes. Defaults to (1., 1.).
        min_gt_mask_area (int): Minimum foreground area of ground truth
            masks. Defaults to 1.
        by_box (bool): Filter instances with bounding boxes not meeting the
            min_gt_bbox_wh threshold. Defaults to True.
        by_mask (bool): Filter instances with masks not meeting the
            min_gt_mask_area threshold. Defaults to False.
        keep_empty (bool): Whether to return None when all boxes are filtered
            out (i.e. the result becomes empty). Defaults to True.
    """

    def __init__(self,
                 min_gt_bbox_wh: Tuple[int, int] = (1, 1),
                 min_gt_mask_area: int = 1,
                 by_box: bool = True,
                 by_mask: bool = False,
                 keep_empty: bool = True) -> None:
        # TODO: add more filter options
        assert by_box or by_mask
        self.min_gt_bbox_wh = min_gt_bbox_wh
        self.min_gt_mask_area = min_gt_mask_area
        self.by_box = by_box
        self.by_mask = by_mask
        self.keep_empty = keep_empty

    @autocast_box_type()
    def transform(self, results: dict) -> Union[dict, None]:
        """Transform function to filter annotations.

        Args:
            results (dict): Result dict.

        Returns:
            dict: Updated result dict.
        """
        assert 'gt_bboxes' in results
        gt_bboxes = results['gt_bboxes']
        if gt_bboxes.shape[0] == 0:
            return results

        tests = []
        if self.by_box:
            tests.append(
                ((gt_bboxes.widths > self.min_gt_bbox_wh[0]) &
                 (gt_bboxes.heights > self.min_gt_bbox_wh[1])).numpy())
        if self.by_mask:
            assert 'gt_masks' in results
            gt_masks = results['gt_masks']
            tests.append(gt_masks.areas >= self.min_gt_mask_area)

        keep = tests[0]
        for t in tests[1:]:
            keep = keep & t

        if not keep.any():
            if self.keep_empty:
                return None

        keys = ('gt_bboxes', 'gt_bboxes_labels', 'gt_masks', 'gt_ignore_flags')
        for key in keys:
            if key in results:
                results[key] = results[key][keep]

        return results

    def __repr__(self):
        return self.__class__.__name__ + \
            f'(min_gt_bbox_wh={self.min_gt_bbox_wh}, ' \
            f'keep_empty={self.keep_empty})'
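

# Usage sketch (illustrative boxes; the second one is a 0.5 x 0.5 px sliver
# that the default ``min_gt_bbox_wh=(1, 1)`` threshold filters out):
#
#   >>> from mmdet.structures.bbox import HorizontalBoxes
#   >>> results = dict(
#   ...     gt_bboxes=HorizontalBoxes(
#   ...         torch.tensor([[0., 0., 20., 20.], [0., 0., 0.5, 0.5]])),
#   ...     gt_bboxes_labels=np.array([0, 1], dtype=np.int64),
#   ...     gt_ignore_flags=np.array([False, False]))
#   >>> results = FilterAnnotations(by_box=True)(results)
#   >>> len(results['gt_bboxes']), results['gt_bboxes_labels']
#   (1, array([0]))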


@TRANSFORMS.register_module()
class LoadEmptyAnnotations(BaseTransform):
    """Load empty annotations for unlabeled images.

    Added Keys:

    - gt_bboxes (np.float32)
    - gt_bboxes_labels (np.int64)
    - gt_masks (BitmapMasks | PolygonMasks)
    - gt_seg_map (np.uint8)
    - gt_ignore_flags (bool)

    Args:
        with_bbox (bool): Whether to load the pseudo bbox annotation.
            Defaults to True.
        with_label (bool): Whether to load the pseudo label annotation.
            Defaults to True.
        with_mask (bool): Whether to load the pseudo mask annotation.
            Defaults to False.
        with_seg (bool): Whether to load the pseudo semantic segmentation
            annotation. Defaults to False.
        seg_ignore_label (int): The fill value used for the segmentation map.
            Note this value must equal ``ignore_label`` in ``semantic_head``
            of the corresponding config. Defaults to 255.
    """

    def __init__(self,
                 with_bbox: bool = True,
                 with_label: bool = True,
                 with_mask: bool = False,
                 with_seg: bool = False,
                 seg_ignore_label: int = 255) -> None:
        self.with_bbox = with_bbox
        self.with_label = with_label
        self.with_mask = with_mask
        self.with_seg = with_seg
        self.seg_ignore_label = seg_ignore_label

    def transform(self, results: dict) -> dict:
        """Transform function to load empty annotations.

        Args:
            results (dict): Result dict.

        Returns:
            dict: Updated result dict.
        """
        if self.with_bbox:
            results['gt_bboxes'] = np.zeros((0, 4), dtype=np.float32)
            results['gt_ignore_flags'] = np.zeros((0, ), dtype=bool)
        if self.with_label:
            results['gt_bboxes_labels'] = np.zeros((0, ), dtype=np.int64)
        if self.with_mask:
            # TODO: support PolygonMasks
            h, w = results['img_shape']
            gt_masks = np.zeros((0, h, w), dtype=np.uint8)
            results['gt_masks'] = BitmapMasks(gt_masks, h, w)
        if self.with_seg:
            h, w = results['img_shape']
            results['gt_seg_map'] = self.seg_ignore_label * np.ones(
                (h, w), dtype=np.uint8)
        return results

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += f'(with_bbox={self.with_bbox}, '
        repr_str += f'with_label={self.with_label}, '
        repr_str += f'with_mask={self.with_mask}, '
        repr_str += f'with_seg={self.with_seg}, '
        repr_str += f'seg_ignore_label={self.seg_ignore_label})'
        return repr_str
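

# Usage sketch (illustrative): fabricating empty ground truth for an
# unlabeled image whose ``img_shape`` has already been set by the image
# loading transform:
#
#   >>> transform = LoadEmptyAnnotations(with_bbox=True, with_label=True,
#   ...                                  with_mask=True, with_seg=True)
#   >>> out = transform(dict(img_shape=(100, 120)))
#   >>> out['gt_bboxes'].shape, len(out['gt_masks']), out['gt_seg_map'].shape
#   ((0, 4), 0, (100, 120))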


@TRANSFORMS.register_module()
class InferencerLoader(BaseTransform):
    """Load an image for inference.

    Similar to :obj:`LoadImageFromFile`, but the input may be a file path, an
    image already loaded as an :obj:`np.ndarray`, or a partially built result
    dict. Can be used when running inference, e.g. on webcam input.

    Required Keys:

    - img

    Modified Keys:

    - img
    - img_path
    - img_shape
    - ori_shape

    Args:
        to_float32 (bool): Whether to convert the loaded image to a float32
            numpy array. If set to False, the loaded image is a uint8 array.
            Defaults to False.
    """

    def __init__(self, **kwargs) -> None:
        super().__init__()
        self.from_file = TRANSFORMS.build(
            dict(type='LoadImageFromFile', **kwargs))
        self.from_ndarray = TRANSFORMS.build(
            dict(type='mmdet.LoadImageFromNDArray', **kwargs))

    def transform(self, results: Union[str, np.ndarray, dict]) -> dict:
        """Transform function to add image meta information.

        Args:
            results (str, np.ndarray or dict): The result.

        Returns:
            dict: The dict contains loaded image and meta information.
        """
        if isinstance(results, str):
            inputs = dict(img_path=results)
        elif isinstance(results, np.ndarray):
            inputs = dict(img=results)
        elif isinstance(results, dict):
            inputs = results
        else:
            raise NotImplementedError

        if 'img' in inputs:
            return self.from_ndarray(inputs)
        return self.from_file(inputs)
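

# Usage sketch (illustrative; 'demo.jpg' is a placeholder path and the calls
# assume the mmdet default registry scope has been initialized). The wrapper
# accepts a path, a raw array, or a partially built dict and dispatches to
# the matching loader:
#
#   >>> loader = InferencerLoader()
#   >>> out = loader(np.zeros((32, 32, 3), dtype=np.uint8))  # ndarray input
#   >>> out['img_shape']
#   (32, 32)
#   >>> out = loader('demo.jpg')  # path input, read from disk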


@TRANSFORMS.register_module()
class LoadTrackAnnotations(LoadAnnotations):
    """Load and process the ``instances`` and ``seg_map`` annotation provided
    by the dataset. It must load ``instances_ids``, which is only used in
    tracking tasks.

    The annotation format is as follows:

    .. code-block:: python

        {
            'instances':
            [
                {
                # List of 4 numbers representing the bounding box of the
                # instance, in (x1, y1, x2, y2) order.
                'bbox': [x1, y1, x2, y2],

                # Label of image classification.
                'bbox_label': 1,

                # Used in tracking.
                # Id of instances.
                'instance_id': 100,

                # Used in instance/panoptic segmentation. The segmentation mask
                # of the instance or the information of segments.
                # 1. If list[list[float]], it represents a list of polygons,
                # one for each connected component of the object. Each
                # list[float] is one simple polygon in the format of
                # [x1, y1, ..., xn, yn] (n >= 3). The Xs and Ys are absolute
                # coordinates in unit of pixels.
                # 2. If dict, it represents the per-pixel segmentation mask in
                # COCO's compressed RLE format. The dict should have keys
                # "size" and "counts". Can be loaded by pycocotools.
                'mask': list[list[float]] or dict,
                }
            ]
            # Filename of semantic or panoptic segmentation ground truth file.
            'seg_map_path': 'a/b/c'
        }

    After this module, the annotation has been changed to the format below:

    .. code-block:: python

        {
            # In (x1, y1, x2, y2) order, float type. N is the number of bboxes
            # in an image
            'gt_bboxes': np.ndarray(N, 4)
             # In int type.
            'gt_bboxes_labels': np.ndarray(N, )
             # In built-in class
            'gt_masks': PolygonMasks (H, W) or BitmapMasks (H, W)
             # In uint8 type.
            'gt_seg_map': np.ndarray (H, W)
        }

    Required Keys:

    - height (optional)
    - width (optional)
    - instances

      - bbox (optional)
      - bbox_label
      - instance_id (optional)
      - mask (optional)
      - ignore_flag (optional)

    - seg_map_path (optional)

    Added Keys:

    - gt_bboxes (np.float32)
    - gt_bboxes_labels (np.int32)
    - gt_instances_ids (np.int32)
    - gt_masks (BitmapMasks | PolygonMasks)
    - gt_seg_map (np.uint8)
    - gt_ignore_flags (np.bool)
    """

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)

    def _load_bboxes(self, results: dict) -> None:
        """Private function to load bounding box annotations.

        Args:
            results (dict): Result dict from :obj:`mmcv.BaseDataset`.

        Returns:
            dict: The dict contains loaded bounding box annotations.
        """
        gt_bboxes = []
        gt_ignore_flags = []
        # TODO: use bbox_type
        for instance in results['instances']:
            # The datasets which are only formatted in evaluation don't have
            # groundtruth boxes.
            if 'bbox' in instance:
                gt_bboxes.append(instance['bbox'])
            if 'ignore_flag' in instance:
                gt_ignore_flags.append(instance['ignore_flag'])

        # TODO: check this case
        if len(gt_bboxes) != len(gt_ignore_flags):
            # There may be no ``gt_ignore_flags`` in some cases, we treat them
            # as all False in order to keep the length of ``gt_bboxes`` and
            # ``gt_ignore_flags`` the same
            gt_ignore_flags = [False] * len(gt_bboxes)

        results['gt_bboxes'] = np.array(
            gt_bboxes, dtype=np.float32).reshape(-1, 4)
        results['gt_ignore_flags'] = np.array(gt_ignore_flags, dtype=bool)

    def _load_instances_ids(self, results: dict) -> None:
        """Private function to load instances id annotations.

        Args:
            results (dict): Result dict from :obj:`mmcv.BaseDataset`.

        Returns:
            dict: The dict containing instances id annotations.
        """
        gt_instances_ids = []
        for instance in results['instances']:
            gt_instances_ids.append(instance['instance_id'])
        results['gt_instances_ids'] = np.array(
            gt_instances_ids, dtype=np.int32)

    def transform(self, results: dict) -> dict:
        """Function to load multiple types of annotations.

        Args:
            results (dict): Result dict from :obj:`mmcv.BaseDataset`.

        Returns:
            dict: The dict contains loaded bounding box, label, instances id
            and semantic segmentation annotations.
        """
        results = super().transform(results)
        self._load_instances_ids(results)
        return results

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += f'(with_bbox={self.with_bbox}, '
        repr_str += f'with_label={self.with_label}, '
        repr_str += f'with_mask={self.with_mask}, '
        repr_str += f'with_seg={self.with_seg}, '
        repr_str += f'poly2mask={self.poly2mask}, '
        repr_str += f"imdecode_backend='{self.imdecode_backend}', "
        repr_str += f'backend_args={self.backend_args})'
        return repr_str
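

# Usage sketch (illustrative instance values): tracking annotations carry an
# ``instance_id`` per instance, which ends up in ``gt_instances_ids``:
#
#   >>> results = dict(
#   ...     ori_shape=(64, 64),
#   ...     instances=[dict(bbox=[1, 1, 10, 10], bbox_label=0,
#   ...                     instance_id=42, ignore_flag=0)])
#   >>> out = LoadTrackAnnotations(with_bbox=True, with_label=True)(results)
#   >>> out['gt_instances_ids']
#   array([42], dtype=int32)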