# det_data_sample.py
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import List, Optional
  3. from mmengine.structures import BaseDataElement, InstanceData, PixelData
  4. class DetDataSample(BaseDataElement):
  5. """A data structure interface of MMDetection. They are used as interfaces
  6. between different components.
  7. The attributes in ``DetDataSample`` are divided into several parts:
  8. - ``proposals``(InstanceData): Region proposals used in two-stage
  9. detectors.
  10. - ``gt_instances``(InstanceData): Ground truth of instance annotations.
  11. - ``pred_instances``(InstanceData): Instances of detection predictions.
  12. - ``pred_track_instances``(InstanceData): Instances of tracking
  13. predictions.
  14. - ``ignored_instances``(InstanceData): Instances to be ignored during
  15. training/testing.
  16. - ``gt_panoptic_seg``(PixelData): Ground truth of panoptic
  17. segmentation.
  18. - ``pred_panoptic_seg``(PixelData): Prediction of panoptic
  19. segmentation.
  20. - ``gt_sem_seg``(PixelData): Ground truth of semantic segmentation.
  21. - ``pred_sem_seg``(PixelData): Prediction of semantic segmentation.
  22. Examples:
  23. >>> import torch
  24. >>> import numpy as np
  25. >>> from mmengine.structures import InstanceData
  26. >>> from mmdet.structures import DetDataSample
  27. >>> data_sample = DetDataSample()
  28. >>> img_meta = dict(img_shape=(800, 1196),
  29. ... pad_shape=(800, 1216))
  30. >>> gt_instances = InstanceData(metainfo=img_meta)
  31. >>> gt_instances.bboxes = torch.rand((5, 4))
  32. >>> gt_instances.labels = torch.rand((5,))
  33. >>> data_sample.gt_instances = gt_instances
  34. >>> assert 'img_shape' in data_sample.gt_instances.metainfo_keys()
  35. >>> len(data_sample.gt_instances)
  36. 5
  37. >>> print(data_sample)
  38. <DetDataSample(
  39. META INFORMATION
  40. DATA FIELDS
  41. gt_instances: <InstanceData(
  42. META INFORMATION
  43. pad_shape: (800, 1216)
  44. img_shape: (800, 1196)
  45. DATA FIELDS
  46. labels: tensor([0.8533, 0.1550, 0.5433, 0.7294, 0.5098])
  47. bboxes:
  48. tensor([[9.7725e-01, 5.8417e-01, 1.7269e-01, 6.5694e-01],
  49. [1.7894e-01, 5.1780e-01, 7.0590e-01, 4.8589e-01],
  50. [7.0392e-01, 6.6770e-01, 1.7520e-01, 1.4267e-01],
  51. [2.2411e-01, 5.1962e-01, 9.6953e-01, 6.6994e-01],
  52. [4.1338e-01, 2.1165e-01, 2.7239e-04, 6.8477e-01]])
  53. ) at 0x7f21fb1b9190>
  54. ) at 0x7f21fb1b9880>
  55. >>> pred_instances = InstanceData(metainfo=img_meta)
  56. >>> pred_instances.bboxes = torch.rand((5, 4))
  57. >>> pred_instances.scores = torch.rand((5,))
  58. >>> data_sample = DetDataSample(pred_instances=pred_instances)
  59. >>> assert 'pred_instances' in data_sample
  60. >>> pred_track_instances = InstanceData(metainfo=img_meta)
  61. >>> pred_track_instances.bboxes = torch.rand((5, 4))
  62. >>> pred_track_instances.scores = torch.rand((5,))
  63. >>> data_sample = DetDataSample(
  64. ... pred_track_instances=pred_track_instances)
  65. >>> assert 'pred_track_instances' in data_sample
  66. >>> data_sample = DetDataSample()
  67. >>> gt_instances_data = dict(
  68. ... bboxes=torch.rand(2, 4),
  69. ... labels=torch.rand(2),
  70. ... masks=np.random.rand(2, 2, 2))
  71. >>> gt_instances = InstanceData(**gt_instances_data)
  72. >>> data_sample.gt_instances = gt_instances
  73. >>> assert 'gt_instances' in data_sample
  74. >>> assert 'masks' in data_sample.gt_instances
  75. >>> data_sample = DetDataSample()
  76. >>> gt_panoptic_seg_data = dict(panoptic_seg=torch.rand(2, 4))
  77. >>> gt_panoptic_seg = PixelData(**gt_panoptic_seg_data)
  78. >>> data_sample.gt_panoptic_seg = gt_panoptic_seg
  79. >>> print(data_sample)
  80. <DetDataSample(
  81. META INFORMATION
  82. DATA FIELDS
  83. _gt_panoptic_seg: <BaseDataElement(
  84. META INFORMATION
  85. DATA FIELDS
  86. panoptic_seg: tensor([[0.7586, 0.1262, 0.2892, 0.9341],
  87. [0.3200, 0.7448, 0.1052, 0.5371]])
  88. ) at 0x7f66c2bb7730>
  89. gt_panoptic_seg: <BaseDataElement(
  90. META INFORMATION
  91. DATA FIELDS
  92. panoptic_seg: tensor([[0.7586, 0.1262, 0.2892, 0.9341],
  93. [0.3200, 0.7448, 0.1052, 0.5371]])
  94. ) at 0x7f66c2bb7730>
  95. ) at 0x7f66c2bb7280>
  96. >>> data_sample = DetDataSample()
  97. >>> gt_segm_seg_data = dict(segm_seg=torch.rand(2, 2, 2))
  98. >>> gt_segm_seg = PixelData(**gt_segm_seg_data)
  99. >>> data_sample.gt_segm_seg = gt_segm_seg
  100. >>> assert 'gt_segm_seg' in data_sample
  101. >>> assert 'segm_seg' in data_sample.gt_segm_seg
  102. """
  103. @property
  104. def proposals(self) -> InstanceData:
  105. return self._proposals
  106. @proposals.setter
  107. def proposals(self, value: InstanceData):
  108. self.set_field(value, '_proposals', dtype=InstanceData)
  109. @proposals.deleter
  110. def proposals(self):
  111. del self._proposals
  112. @property
  113. def gt_instances(self) -> InstanceData:
  114. return self._gt_instances
  115. @gt_instances.setter
  116. def gt_instances(self, value: InstanceData):
  117. self.set_field(value, '_gt_instances', dtype=InstanceData)
  118. @gt_instances.deleter
  119. def gt_instances(self):
  120. del self._gt_instances
  121. @property
  122. def pred_instances(self) -> InstanceData:
  123. return self._pred_instances
  124. @pred_instances.setter
  125. def pred_instances(self, value: InstanceData):
  126. self.set_field(value, '_pred_instances', dtype=InstanceData)
  127. @pred_instances.deleter
  128. def pred_instances(self):
  129. del self._pred_instances
  130. # directly add ``pred_track_instances`` in ``DetDataSample``
  131. # so that the ``TrackDataSample`` does not bother to access the
  132. # instance-level information.
  133. @property
  134. def pred_track_instances(self) -> InstanceData:
  135. return self._pred_track_instances
  136. @pred_track_instances.setter
  137. def pred_track_instances(self, value: InstanceData):
  138. self.set_field(value, '_pred_track_instances', dtype=InstanceData)
  139. @pred_track_instances.deleter
  140. def pred_track_instances(self):
  141. del self._pred_track_instances
  142. @property
  143. def ignored_instances(self) -> InstanceData:
  144. return self._ignored_instances
  145. @ignored_instances.setter
  146. def ignored_instances(self, value: InstanceData):
  147. self.set_field(value, '_ignored_instances', dtype=InstanceData)
  148. @ignored_instances.deleter
  149. def ignored_instances(self):
  150. del self._ignored_instances
  151. @property
  152. def gt_panoptic_seg(self) -> PixelData:
  153. return self._gt_panoptic_seg
  154. @gt_panoptic_seg.setter
  155. def gt_panoptic_seg(self, value: PixelData):
  156. self.set_field(value, '_gt_panoptic_seg', dtype=PixelData)
  157. @gt_panoptic_seg.deleter
  158. def gt_panoptic_seg(self):
  159. del self._gt_panoptic_seg
  160. @property
  161. def pred_panoptic_seg(self) -> PixelData:
  162. return self._pred_panoptic_seg
  163. @pred_panoptic_seg.setter
  164. def pred_panoptic_seg(self, value: PixelData):
  165. self.set_field(value, '_pred_panoptic_seg', dtype=PixelData)
  166. @pred_panoptic_seg.deleter
  167. def pred_panoptic_seg(self):
  168. del self._pred_panoptic_seg
  169. @property
  170. def gt_sem_seg(self) -> PixelData:
  171. return self._gt_sem_seg
  172. @gt_sem_seg.setter
  173. def gt_sem_seg(self, value: PixelData):
  174. self.set_field(value, '_gt_sem_seg', dtype=PixelData)
  175. @gt_sem_seg.deleter
  176. def gt_sem_seg(self):
  177. del self._gt_sem_seg
  178. @property
  179. def pred_sem_seg(self) -> PixelData:
  180. return self._pred_sem_seg
  181. @pred_sem_seg.setter
  182. def pred_sem_seg(self, value: PixelData):
  183. self.set_field(value, '_pred_sem_seg', dtype=PixelData)
  184. @pred_sem_seg.deleter
  185. def pred_sem_seg(self):
  186. del self._pred_sem_seg
  187. SampleList = List[DetDataSample]
  188. OptSampleList = Optional[SampleList]