track_data_sample.py

# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Sequence

import numpy as np
import torch
from mmengine.structures import BaseDataElement

from .det_data_sample import DetDataSample

class TrackDataSample(BaseDataElement):
    """A data structure interface of the tracking task in MMDetection. It is
    used as an interface between different components.

    This data structure can be viewed as a wrapper of multiple DetDataSample
    to some extent. Specifically, it only contains one property:
    ``video_data_samples``, which is a list of DetDataSample, each of which
    corresponds to a single frame. To get a property of a single frame, you
    must first get the corresponding ``DetDataSample`` by indexing and then
    access the property of that frame, such as ``gt_instances``,
    ``pred_instances`` and so on. As for metainfo, it differs from
    ``DetDataSample`` in that each value corresponding to a metainfo key is a
    list where each element holds the information of a single frame.

    Examples:
        >>> import torch
        >>> from mmengine.structures import InstanceData
        >>> from mmdet.structures import DetDataSample, TrackDataSample
        >>> track_data_sample = TrackDataSample()
        >>> # set the 1st frame
        >>> frame1_data_sample = DetDataSample(metainfo=dict(
        ...     img_shape=(100, 100), frame_id=0))
        >>> frame1_gt_instances = InstanceData()
        >>> frame1_gt_instances.bbox = torch.zeros([2, 4])
        >>> frame1_data_sample.gt_instances = frame1_gt_instances
        >>> # set the 2nd frame
        >>> frame2_data_sample = DetDataSample(metainfo=dict(
        ...     img_shape=(100, 100), frame_id=1))
        >>> frame2_gt_instances = InstanceData()
        >>> frame2_gt_instances.bbox = torch.ones([3, 4])
        >>> frame2_data_sample.gt_instances = frame2_gt_instances
        >>> track_data_sample.video_data_samples = [frame1_data_sample,
        ...                                         frame2_data_sample]
        >>> # set metainfo for track_data_sample
        >>> track_data_sample.set_metainfo(dict(key_frames_inds=[0]))
        >>> track_data_sample.set_metainfo(dict(ref_frames_inds=[1]))
        >>> print(track_data_sample)
        <TrackDataSample(
            META INFORMATION
            key_frames_inds: [0]
            ref_frames_inds: [1]
            DATA FIELDS
            video_data_samples: [<DetDataSample(
                    META INFORMATION
                    img_shape: (100, 100)
                    DATA FIELDS
                    gt_instances: <InstanceData(
                            META INFORMATION
                            DATA FIELDS
                            bbox: tensor([[0., 0., 0., 0.],
                                          [0., 0., 0., 0.]])
                        ) at 0x7f639320dcd0>
                ) at 0x7f64bd223340>, <DetDataSample(
                    META INFORMATION
                    img_shape: (100, 100)
                    DATA FIELDS
                    gt_instances: <InstanceData(
                            META INFORMATION
                            DATA FIELDS
                            bbox: tensor([[1., 1., 1., 1.],
                                          [1., 1., 1., 1.],
                                          [1., 1., 1., 1.]])
                        ) at 0x7f64bd128b20>
                ) at 0x7f64bd1346d0>]
        ) at 0x7f64bd2237f0>
        >>> print(len(track_data_sample))
        2
        >>> key_data_sample = track_data_sample.get_key_frames()
        >>> print(key_data_sample[0].frame_id)
        0
        >>> ref_data_sample = track_data_sample.get_ref_frames()
        >>> print(ref_data_sample[0].frame_id)
        1
        >>> frame1_data_sample = track_data_sample[0]
        >>> print(frame1_data_sample.gt_instances.bbox)
        tensor([[0., 0., 0., 0.],
                [0., 0., 0., 0.]])
        >>> # Tensor-like methods
        >>> cuda_track_data_sample = track_data_sample.to('cuda')
        >>> cuda_track_data_sample = track_data_sample.cuda()
        >>> cpu_track_data_sample = track_data_sample.cpu()
        >>> cpu_track_data_sample = track_data_sample.to('cpu')
        >>> fp16_instances = cuda_track_data_sample.to(
        ...     device=None, dtype=torch.float16, non_blocking=False,
        ...     copy=False, memory_format=torch.preserve_format)
    """

    @property
    def video_data_samples(self) -> List[DetDataSample]:
        return self._video_data_samples

    @video_data_samples.setter
    def video_data_samples(self, value: List[DetDataSample]):
        if isinstance(value, DetDataSample):
            value = [value]
        assert isinstance(value, list), 'video_data_samples must be a list'
        assert isinstance(value[0], DetDataSample), (
            'video_data_samples must be a list of DetDataSample, but got '
            f'{value[0]}')
        self.set_field(value, '_video_data_samples', dtype=list)

    @video_data_samples.deleter
    def video_data_samples(self):
        del self._video_data_samples
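
    # Illustrative note (not from the original source): the setter above
    # normalizes a bare DetDataSample into a one-element list, so both of
    # the following assignments are equivalent (`ts` is a hypothetical name):
    #
    #   >>> ts = TrackDataSample()
    #   >>> ts.video_data_samples = DetDataSample()    # stored as [sample]
    #   >>> ts.video_data_samples = [DetDataSample()]  # same result
    #   >>> len(ts)
    #   1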

    def __getitem__(self, index):
        assert hasattr(self,
                       '_video_data_samples'), 'video_data_samples not set'
        return self._video_data_samples[index]

    def get_key_frames(self):
        assert hasattr(self, 'key_frames_inds'), \
            'key_frames_inds not set'
        assert isinstance(self.key_frames_inds, Sequence)
        key_frames_info = []
        for index in self.key_frames_inds:
            key_frames_info.append(self[index])
        return key_frames_info

    def get_ref_frames(self):
        assert hasattr(self, 'ref_frames_inds'), \
            'ref_frames_inds not set'
        assert isinstance(self.ref_frames_inds, Sequence)
        ref_frames_info = []
        for index in self.ref_frames_inds:
            ref_frames_info.append(self[index])
        return ref_frames_info

    def __len__(self):
        return len(self._video_data_samples) if hasattr(
            self, '_video_data_samples') else 0

    # TODO: add UT for this Tensor-like method
    # Tensor-like methods
    def to(self, *args, **kwargs) -> 'BaseDataElement':
        """Apply the same-name function to all tensors in data fields."""
        new_data = self.new()
        for k, v_list in self.items():
            data_list = []
            for v in v_list:
                if hasattr(v, 'to'):
                    v = v.to(*args, **kwargs)
                data_list.append(v)
            if len(data_list) > 0:
                new_data.set_data({f'{k}': data_list})
        return new_data
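
    # Usage sketch (illustrative; the 'cuda' line assumes a CUDA device):
    # because every data field of TrackDataSample holds a list of per-frame
    # elements, `to()` maps the conversion over each element, e.g.
    #
    #   >>> fp16_sample = track_data_sample.to(torch.float16)
    #   >>> gpu_sample = track_data_sample.to('cuda')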

    # Tensor-like methods
    def cpu(self) -> 'BaseDataElement':
        """Convert all tensors to CPU in data."""
        new_data = self.new()
        for k, v_list in self.items():
            data_list = []
            for v in v_list:
                if isinstance(v, (torch.Tensor, BaseDataElement)):
                    v = v.cpu()
                data_list.append(v)
            if len(data_list) > 0:
                new_data.set_data({f'{k}': data_list})
        return new_data

    # Tensor-like methods
    def cuda(self) -> 'BaseDataElement':
        """Convert all tensors to GPU in data."""
        new_data = self.new()
        for k, v_list in self.items():
            data_list = []
            for v in v_list:
                if isinstance(v, (torch.Tensor, BaseDataElement)):
                    v = v.cuda()
                data_list.append(v)
            if len(data_list) > 0:
                new_data.set_data({f'{k}': data_list})
        return new_data

    # Tensor-like methods
    def npu(self) -> 'BaseDataElement':
        """Convert all tensors to NPU in data."""
        new_data = self.new()
        for k, v_list in self.items():
            data_list = []
            for v in v_list:
                if isinstance(v, (torch.Tensor, BaseDataElement)):
                    v = v.npu()
                data_list.append(v)
            if len(data_list) > 0:
                new_data.set_data({f'{k}': data_list})
        return new_data

    # Tensor-like methods
    def detach(self) -> 'BaseDataElement':
        """Detach all tensors in data."""
        new_data = self.new()
        for k, v_list in self.items():
            data_list = []
            for v in v_list:
                if isinstance(v, (torch.Tensor, BaseDataElement)):
                    v = v.detach()
                data_list.append(v)
            if len(data_list) > 0:
                new_data.set_data({f'{k}': data_list})
        return new_data

    # Tensor-like methods
    def numpy(self) -> 'BaseDataElement':
        """Convert all tensors to np.ndarray in data."""
        new_data = self.new()
        for k, v_list in self.items():
            data_list = []
            for v in v_list:
                if isinstance(v, (torch.Tensor, BaseDataElement)):
                    v = v.detach().cpu().numpy()
                data_list.append(v)
            if len(data_list) > 0:
                new_data.set_data({f'{k}': data_list})
        return new_data

    def to_tensor(self) -> 'BaseDataElement':
        """Convert all np.ndarray to tensor in data."""
        new_data = self.new()
        for k, v_list in self.items():
            data_list = []
            for v in v_list:
                if isinstance(v, np.ndarray):
                    v = torch.from_numpy(v)
                elif isinstance(v, BaseDataElement):
                    v = v.to_tensor()
                data_list.append(v)
            if len(data_list) > 0:
                new_data.set_data({f'{k}': data_list})
        return new_data
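
    # Round-trip sketch (illustrative): `numpy()` detaches each tensor, moves
    # it to CPU and converts it to np.ndarray; `to_tensor()` converts the
    # arrays back to torch tensors, e.g.
    #
    #   >>> np_sample = track_data_sample.numpy()
    #   >>> tensor_sample = np_sample.to_tensor()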

    # Tensor-like methods
    def clone(self) -> 'BaseDataElement':
        """Deep copy the current data element.

        Returns:
            BaseDataElement: The copy of current data element.
        """
        clone_data = self.__class__()
        clone_data.set_metainfo(dict(self.metainfo_items()))
        for k, v_list in self.items():
            clone_item_list = []
            for v in v_list:
                clone_item_list.append(v.clone())
            clone_data.set_data({k: clone_item_list})
        return clone_data


TrackSampleList = List[TrackDataSample]
OptTrackSampleList = Optional[TrackSampleList]
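

# Minimal usage sketch (an illustration, not part of the original module):
# builds a two-frame TrackDataSample and exercises indexing, key/ref frame
# lookup and dtype conversion. Assuming this file lives at
# mmdet/structures/track_data_sample.py in an installed mmdet, it can be run
# with `python -m mmdet.structures.track_data_sample`.
if __name__ == '__main__':
    from mmengine.structures import InstanceData

    frames = []
    for frame_id, num_boxes in enumerate((2, 3)):
        frame = DetDataSample(
            metainfo=dict(img_shape=(100, 100), frame_id=frame_id))
        gt_instances = InstanceData()
        gt_instances.bbox = torch.zeros([num_boxes, 4])
        frame.gt_instances = gt_instances
        frames.append(frame)

    sample = TrackDataSample()
    sample.video_data_samples = frames
    sample.set_metainfo(dict(key_frames_inds=[0], ref_frames_inds=[1]))

    assert len(sample) == 2
    assert sample.get_key_frames()[0].frame_id == 0
    assert sample.get_ref_frames()[0].frame_id == 1
    # `to()` maps the conversion over every per-frame DetDataSample.
    fp16_sample = sample.to(torch.float16)
    assert fp16_sample[0].gt_instances.bbox.dtype == torch.float16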