base_tracker.py 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from abc import ABCMeta, abstractmethod
  3. from typing import List, Optional, Tuple
  4. import torch
  5. import torch.nn.functional as F
  6. from addict import Dict
  7. class BaseTracker(metaclass=ABCMeta):
  8. """Base tracker model.
  9. Args:
  10. momentums (dict[str:float], optional): Momentums to update the buffers.
  11. The `str` indicates the name of the buffer while the `float`
  12. indicates the momentum. Defaults to None.
  13. num_frames_retain (int, optional). If a track is disappeared more than
  14. `num_frames_retain` frames, it will be deleted in the memo.
  15. Defaults to 10.
  16. """
  17. def __init__(self,
  18. momentums: Optional[dict] = None,
  19. num_frames_retain: int = 10) -> None:
  20. super().__init__()
  21. if momentums is not None:
  22. assert isinstance(momentums, dict), 'momentums must be a dict'
  23. self.momentums = momentums
  24. self.num_frames_retain = num_frames_retain
  25. self.reset()
  26. def reset(self) -> None:
  27. """Reset the buffer of the tracker."""
  28. self.num_tracks = 0
  29. self.tracks = dict()
  30. @property
  31. def empty(self) -> bool:
  32. """Whether the buffer is empty or not."""
  33. return False if self.tracks else True
  34. @property
  35. def ids(self) -> List[dict]:
  36. """All ids in the tracker."""
  37. return list(self.tracks.keys())
  38. @property
  39. def with_reid(self) -> bool:
  40. """bool: whether the framework has a reid model"""
  41. return hasattr(self, 'reid') and self.reid is not None
  42. def update(self, **kwargs) -> None:
  43. """Update the tracker.
  44. Args:
  45. kwargs (dict[str: Tensor | int]): The `str` indicates the
  46. name of the input variable. `ids` and `frame_ids` are
  47. obligatory in the keys.
  48. """
  49. memo_items = [k for k, v in kwargs.items() if v is not None]
  50. rm_items = [k for k in kwargs.keys() if k not in memo_items]
  51. for item in rm_items:
  52. kwargs.pop(item)
  53. if not hasattr(self, 'memo_items'):
  54. self.memo_items = memo_items
  55. else:
  56. assert memo_items == self.memo_items
  57. assert 'ids' in memo_items
  58. num_objs = len(kwargs['ids'])
  59. id_indice = memo_items.index('ids')
  60. assert 'frame_ids' in memo_items
  61. frame_id = int(kwargs['frame_ids'])
  62. if isinstance(kwargs['frame_ids'], int):
  63. kwargs['frame_ids'] = torch.tensor([kwargs['frame_ids']] *
  64. num_objs)
  65. # cur_frame_id = int(kwargs['frame_ids'][0])
  66. for k, v in kwargs.items():
  67. if len(v) != num_objs:
  68. raise ValueError('kwargs value must both equal')
  69. for obj in zip(*kwargs.values()):
  70. id = int(obj[id_indice])
  71. if id in self.tracks:
  72. self.update_track(id, obj)
  73. else:
  74. self.init_track(id, obj)
  75. self.pop_invalid_tracks(frame_id)
  76. def pop_invalid_tracks(self, frame_id: int) -> None:
  77. """Pop out invalid tracks."""
  78. invalid_ids = []
  79. for k, v in self.tracks.items():
  80. if frame_id - v['frame_ids'][-1] >= self.num_frames_retain:
  81. invalid_ids.append(k)
  82. for invalid_id in invalid_ids:
  83. self.tracks.pop(invalid_id)
  84. def update_track(self, id: int, obj: Tuple[torch.Tensor]):
  85. """Update a track."""
  86. for k, v in zip(self.memo_items, obj):
  87. v = v[None]
  88. if self.momentums is not None and k in self.momentums:
  89. m = self.momentums[k]
  90. self.tracks[id][k] = (1 - m) * self.tracks[id][k] + m * v
  91. else:
  92. self.tracks[id][k].append(v)
  93. def init_track(self, id: int, obj: Tuple[torch.Tensor]):
  94. """Initialize a track."""
  95. self.tracks[id] = Dict()
  96. for k, v in zip(self.memo_items, obj):
  97. v = v[None]
  98. if self.momentums is not None and k in self.momentums:
  99. self.tracks[id][k] = v
  100. else:
  101. self.tracks[id][k] = [v]
  102. @property
  103. def memo(self) -> dict:
  104. """Return all buffers in the tracker."""
  105. outs = Dict()
  106. for k in self.memo_items:
  107. outs[k] = []
  108. for id, objs in self.tracks.items():
  109. for k, v in objs.items():
  110. if k not in outs:
  111. continue
  112. if self.momentums is not None and k in self.momentums:
  113. v = v
  114. else:
  115. v = v[-1]
  116. outs[k].append(v)
  117. for k, v in outs.items():
  118. outs[k] = torch.cat(v, dim=0)
  119. return outs
  120. def get(self,
  121. item: str,
  122. ids: Optional[list] = None,
  123. num_samples: Optional[int] = None,
  124. behavior: Optional[str] = None) -> torch.Tensor:
  125. """Get the buffer of a specific item.
  126. Args:
  127. item (str): The demanded item.
  128. ids (list[int], optional): The demanded ids. Defaults to None.
  129. num_samples (int, optional): Number of samples to calculate the
  130. results. Defaults to None.
  131. behavior (str, optional): Behavior to calculate the results.
  132. Options are `mean` | None. Defaults to None.
  133. Returns:
  134. Tensor: The results of the demanded item.
  135. """
  136. if ids is None:
  137. ids = self.ids
  138. outs = []
  139. for id in ids:
  140. out = self.tracks[id][item]
  141. if isinstance(out, list):
  142. if num_samples is not None:
  143. out = out[-num_samples:]
  144. out = torch.cat(out, dim=0)
  145. if behavior == 'mean':
  146. out = out.mean(dim=0, keepdim=True)
  147. elif behavior is None:
  148. out = out[None]
  149. else:
  150. raise NotImplementedError()
  151. else:
  152. out = out[-1]
  153. outs.append(out)
  154. return torch.cat(outs, dim=0)
  155. @abstractmethod
  156. def track(self, *args, **kwargs):
  157. """Tracking forward function."""
  158. pass
  159. def crop_imgs(self,
  160. img: torch.Tensor,
  161. meta_info: dict,
  162. bboxes: torch.Tensor,
  163. rescale: bool = False) -> torch.Tensor:
  164. """Crop the images according to some bounding boxes. Typically for re-
  165. identification sub-module.
  166. Args:
  167. img (Tensor): of shape (T, C, H, W) encoding input image.
  168. Typically these should be mean centered and std scaled.
  169. meta_info (dict): image information dict where each dict
  170. has: 'img_shape', 'scale_factor', 'flip', and may also contain
  171. 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
  172. bboxes (Tensor): of shape (N, 4) or (N, 5).
  173. rescale (bool, optional): If True, the bounding boxes should be
  174. rescaled to fit the scale of the image. Defaults to False.
  175. Returns:
  176. Tensor: Image tensor of shape (T, C, H, W).
  177. """
  178. h, w = meta_info['img_shape']
  179. img = img[:, :, :h, :w]
  180. if rescale:
  181. factor_x, factor_y = meta_info['scale_factor']
  182. bboxes[:, :4] *= torch.tensor(
  183. [factor_x, factor_y, factor_x, factor_y]).to(bboxes.device)
  184. bboxes[:, 0] = torch.clamp(bboxes[:, 0], min=0, max=w - 1)
  185. bboxes[:, 1] = torch.clamp(bboxes[:, 1], min=0, max=h - 1)
  186. bboxes[:, 2] = torch.clamp(bboxes[:, 2], min=1, max=w)
  187. bboxes[:, 3] = torch.clamp(bboxes[:, 3], min=1, max=h)
  188. crop_imgs = []
  189. for bbox in bboxes:
  190. x1, y1, x2, y2 = map(int, bbox)
  191. if x2 <= x1:
  192. x2 = x1 + 1
  193. if y2 <= y1:
  194. y2 = y1 + 1
  195. crop_img = img[:, :, y1:y2, x1:x2]
  196. if self.reid.get('img_scale', False):
  197. crop_img = F.interpolate(
  198. crop_img,
  199. size=self.reid['img_scale'],
  200. mode='bilinear',
  201. align_corners=False)
  202. crop_imgs.append(crop_img)
  203. if len(crop_imgs) > 0:
  204. return torch.cat(crop_imgs, dim=0)
  205. elif self.reid.get('img_scale', False):
  206. _h, _w = self.reid['img_scale']
  207. return img.new_zeros((0, 3, _h, _w))
  208. else:
  209. return img.new_zeros((0, 3, h, w))