aflink.py 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from collections import defaultdict
  3. from typing import Tuple
  4. import numpy as np
  5. import torch
  6. from mmengine.model import BaseModule
  7. from mmengine.runner.checkpoint import load_checkpoint
  8. from scipy.optimize import linear_sum_assignment
  9. from torch import Tensor, nn
  10. from mmdet.registry import TASK_UTILS
  11. INFINITY = 1e5
  12. class TemporalBlock(BaseModule):
  13. """The temporal block of AFLink model.
  14. Args:
  15. in_channel (int): the dimension of the input channels.
  16. out_channel (int): the dimension of the output channels.
  17. """
  18. def __init__(self,
  19. in_channel: int,
  20. out_channel: int,
  21. kernel_size: tuple = (7, 1)):
  22. super(TemporalBlock, self).__init__()
  23. self.conv = nn.Conv2d(in_channel, out_channel, kernel_size, bias=False)
  24. self.relu = nn.ReLU(inplace=True)
  25. self.bnf = nn.BatchNorm1d(out_channel)
  26. self.bnx = nn.BatchNorm1d(out_channel)
  27. self.bny = nn.BatchNorm1d(out_channel)
  28. def bn(self, x: Tensor) -> Tensor:
  29. x[:, :, :, 0] = self.bnf(x[:, :, :, 0])
  30. x[:, :, :, 1] = self.bnx(x[:, :, :, 1])
  31. x[:, :, :, 2] = self.bny(x[:, :, :, 2])
  32. return x
  33. def forward(self, x: Tensor) -> Tensor:
  34. x = self.conv(x)
  35. x = self.bn(x)
  36. x = self.relu(x)
  37. return x
  38. class FusionBlock(BaseModule):
  39. """The fusion block of AFLink model.
  40. Args:
  41. in_channel (int): the dimension of the input channels.
  42. out_channel (int): the dimension of the output channels.
  43. """
  44. def __init__(self, in_channel: int, out_channel: int):
  45. super(FusionBlock, self).__init__()
  46. self.conv = nn.Conv2d(in_channel, out_channel, (1, 3), bias=False)
  47. self.bn = nn.BatchNorm2d(out_channel)
  48. self.relu = nn.ReLU(inplace=True)
  49. def forward(self, x: Tensor) -> Tensor:
  50. x = self.conv(x)
  51. x = self.bn(x)
  52. x = self.relu(x)
  53. return x
  54. class Classifier(BaseModule):
  55. """The classifier of AFLink model.
  56. Args:
  57. in_channel (int): the dimension of the input channels.
  58. """
  59. def __init__(self, in_channel: int, out_channel: int):
  60. super(Classifier, self).__init__()
  61. self.fc1 = nn.Linear(in_channel * 2, in_channel // 2)
  62. self.relu = nn.ReLU(inplace=True)
  63. self.fc2 = nn.Linear(in_channel // 2, out_channel)
  64. def forward(self, x1: Tensor, x2: Tensor) -> Tensor:
  65. x = torch.cat((x1, x2), dim=1)
  66. x = self.fc1(x)
  67. x = self.relu(x)
  68. x = self.fc2(x)
  69. return x
  70. class AFLinkModel(BaseModule):
  71. """Appearance-Free Link Model."""
  72. def __init__(self,
  73. temporal_module_channels: list = [1, 32, 64, 128, 256],
  74. fusion_module_channels: list = [256, 256],
  75. classifier_channels: list = [256, 2]):
  76. super(AFLinkModel, self).__init__()
  77. self.TemporalModule_1 = nn.Sequential(*[
  78. TemporalBlock(temporal_module_channels[i],
  79. temporal_module_channels[i + 1])
  80. for i in range(len(temporal_module_channels) - 1)
  81. ])
  82. self.TemporalModule_2 = nn.Sequential(*[
  83. TemporalBlock(temporal_module_channels[i],
  84. temporal_module_channels[i + 1])
  85. for i in range(len(temporal_module_channels) - 1)
  86. ])
  87. self.FusionBlock_1 = FusionBlock(*fusion_module_channels)
  88. self.FusionBlock_2 = FusionBlock(*fusion_module_channels)
  89. self.pooling = nn.AdaptiveAvgPool2d((1, 1))
  90. self.classifier = Classifier(*classifier_channels)
  91. def forward(self, x1: Tensor, x2: Tensor) -> Tensor:
  92. assert not self.training, 'Only testing is supported for AFLink.'
  93. x1 = x1[:, :, :, :3]
  94. x2 = x2[:, :, :, :3]
  95. x1 = self.TemporalModule_1(x1) # [B,1,30,3] -> [B,256,6,3]
  96. x2 = self.TemporalModule_2(x2)
  97. x1 = self.FusionBlock_1(x1)
  98. x2 = self.FusionBlock_2(x2)
  99. x1 = self.pooling(x1).squeeze(-1).squeeze(-1)
  100. x2 = self.pooling(x2).squeeze(-1).squeeze(-1)
  101. y = self.classifier(x1, x2)
  102. y = torch.softmax(y, dim=1)[0, 1]
  103. return y
  104. @TASK_UTILS.register_module()
  105. class AppearanceFreeLink(BaseModule):
  106. """Appearance-Free Link method.
  107. This method is proposed in
  108. "StrongSORT: Make DeepSORT Great Again"
  109. `StrongSORT<https://arxiv.org/abs/2202.13514>`_.
  110. Args:
  111. checkpoint (str): Checkpoint path.
  112. temporal_threshold (tuple, optional): The temporal constraint
  113. for tracklets association. Defaults to (0, 30).
  114. spatial_threshold (int, optional): The spatial constraint for
  115. tracklets association. Defaults to 75.
  116. confidence_threshold (float, optional): The minimum confidence
  117. threshold for tracklets association. Defaults to 0.95.
  118. """
  119. def __init__(self,
  120. checkpoint: str,
  121. temporal_threshold: tuple = (0, 30),
  122. spatial_threshold: int = 75,
  123. confidence_threshold: float = 0.95):
  124. super(AppearanceFreeLink, self).__init__()
  125. self.temporal_threshold = temporal_threshold
  126. self.spatial_threshold = spatial_threshold
  127. self.confidence_threshold = confidence_threshold
  128. self.model = AFLinkModel()
  129. if checkpoint:
  130. load_checkpoint(self.model, checkpoint)
  131. if torch.cuda.is_available():
  132. self.model.cuda()
  133. self.model.eval()
  134. self.device = next(self.model.parameters()).device
  135. self.fn_l2 = lambda x, y: np.sqrt(x**2 + y**2)
  136. def data_transform(self,
  137. track1: np.ndarray,
  138. track2: np.ndarray,
  139. length: int = 30) -> Tuple[np.ndarray]:
  140. """Data Transformation. This is used to standardize the length of
  141. tracks to a unified length. Then perform min-max normalization to the
  142. motion embeddings.
  143. Args:
  144. track1 (ndarray): the first track with shape (N,C).
  145. track2 (ndarray): the second track with shape (M,C).
  146. length (int): the unified length of tracks. Defaults to 30.
  147. Returns:
  148. Tuple[ndarray]: the transformed track1 and track2.
  149. """
  150. # fill or cut track1
  151. length_1 = track1.shape[0]
  152. track1 = track1[-length:] if length_1 >= length else \
  153. np.pad(track1, ((length - length_1, 0), (0, 0)))
  154. # fill or cut track1
  155. length_2 = track2.shape[0]
  156. track2 = track2[:length] if length_2 >= length else \
  157. np.pad(track2, ((0, length - length_2), (0, 0)))
  158. # min-max normalization
  159. min_ = np.concatenate((track1, track2), axis=0).min(axis=0)
  160. max_ = np.concatenate((track1, track2), axis=0).max(axis=0)
  161. subtractor = (max_ + min_) / 2
  162. divisor = (max_ - min_) / 2 + 1e-5
  163. track1 = (track1 - subtractor) / divisor
  164. track2 = (track2 - subtractor) / divisor
  165. return track1, track2
  166. def forward(self, pred_tracks: np.ndarray) -> np.ndarray:
  167. """Forward function.
  168. pred_tracks (ndarray): With shape (N, 7). Each row denotes
  169. (frame_id, track_id, x1, y1, x2, y2, score).
  170. Returns:
  171. ndarray: The linked tracks with shape (N, 7). Each row denotes
  172. (frame_id, track_id, x1, y1, x2, y2, score)
  173. """
  174. # sort tracks by the frame id
  175. pred_tracks = pred_tracks[np.argsort(pred_tracks[:, 0])]
  176. # gather tracks information
  177. id2info = defaultdict(list)
  178. for row in pred_tracks:
  179. frame_id, track_id, x1, y1, x2, y2 = row[:6]
  180. id2info[track_id].append([frame_id, x1, y1, x2 - x1, y2 - y1])
  181. id2info = {k: np.array(v) for k, v in id2info.items()}
  182. num_track = len(id2info)
  183. track_ids = np.array(list(id2info))
  184. cost_matrix = np.full((num_track, num_track), INFINITY)
  185. # compute the cost matrix
  186. for i, id_i in enumerate(track_ids):
  187. for j, id_j in enumerate(track_ids):
  188. if id_i == id_j:
  189. continue
  190. info_i, info_j = id2info[id_i], id2info[id_j]
  191. frame_i, box_i = info_i[-1][0], info_i[-1][1:3]
  192. frame_j, box_j = info_j[0][0], info_j[0][1:3]
  193. # temporal constraint
  194. if not self.temporal_threshold[0] <= \
  195. frame_j - frame_i <= self.temporal_threshold[1]:
  196. continue
  197. # spatial constraint
  198. if self.fn_l2(box_i[0] - box_j[0], box_i[1] - box_j[1]) \
  199. > self.spatial_threshold:
  200. continue
  201. # confidence constraint
  202. track_i, track_j = self.data_transform(info_i, info_j)
  203. # numpy to torch
  204. track_i = torch.tensor(
  205. track_i, dtype=torch.float).to(self.device)
  206. track_j = torch.tensor(
  207. track_j, dtype=torch.float).to(self.device)
  208. track_i = track_i.unsqueeze(0).unsqueeze(0)
  209. track_j = track_j.unsqueeze(0).unsqueeze(0)
  210. confidence = self.model(track_i,
  211. track_j).detach().cpu().numpy()
  212. if confidence >= self.confidence_threshold:
  213. cost_matrix[i, j] = 1 - confidence
  214. # linear assignment
  215. indices = linear_sum_assignment(cost_matrix)
  216. _id2id = dict() # the temporary assignment results
  217. id2id = dict() # the final assignment results
  218. for i, j in zip(indices[0], indices[1]):
  219. if cost_matrix[i, j] < INFINITY:
  220. _id2id[i] = j
  221. for k, v in _id2id.items():
  222. if k in id2id:
  223. id2id[v] = id2id[k]
  224. else:
  225. id2id[v] = k
  226. # link
  227. for k, v in id2id.items():
  228. pred_tracks[pred_tracks[:, 1] == k, 1] = v
  229. # deduplicate
  230. _, index = np.unique(pred_tracks[:, :2], return_index=True, axis=0)
  231. return pred_tracks[index]