interpolation.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import numpy as np
  3. try:
  4. from sklearn.gaussian_process import GaussianProcessRegressor as GPR
  5. from sklearn.gaussian_process.kernels import RBF
  6. HAS_SKIKIT_LEARN = True
  7. except ImportError:
  8. HAS_SKIKIT_LEARN = False
  9. from mmdet.registry import TASK_UTILS
  10. @TASK_UTILS.register_module()
  11. class InterpolateTracklets:
  12. """Interpolate tracks to make tracks more complete.
  13. Args:
  14. min_num_frames (int, optional): The minimum length of a track that will
  15. be interpolated. Defaults to 5.
  16. max_num_frames (int, optional): The maximum disconnected length in
  17. a track. Defaults to 20.
  18. use_gsi (bool, optional): Whether to use the GSI (Gaussian-smoothed
  19. interpolation) method. Defaults to False.
  20. smooth_tau (int, optional): smoothing parameter in GSI. Defaults to 10.
  21. """
  22. def __init__(self,
  23. min_num_frames: int = 5,
  24. max_num_frames: int = 20,
  25. use_gsi: bool = False,
  26. smooth_tau: int = 10):
  27. if not HAS_SKIKIT_LEARN:
  28. raise RuntimeError('sscikit-learn is not installed,\
  29. please install it by: pip install scikit-learn')
  30. self.min_num_frames = min_num_frames
  31. self.max_num_frames = max_num_frames
  32. self.use_gsi = use_gsi
  33. self.smooth_tau = smooth_tau
  34. def _interpolate_track(self,
  35. track: np.ndarray,
  36. track_id: int,
  37. max_num_frames: int = 20) -> np.ndarray:
  38. """Interpolate a track linearly to make the track more complete.
  39. This function is proposed in
  40. "ByteTrack: Multi-Object Tracking by Associating Every Detection Box."
  41. `ByteTrack<https://arxiv.org/abs/2110.06864>`_.
  42. Args:
  43. track (ndarray): With shape (N, 7). Each row denotes
  44. (frame_id, track_id, x1, y1, x2, y2, score).
  45. max_num_frames (int, optional): The maximum disconnected length in
  46. the track. Defaults to 20.
  47. Returns:
  48. ndarray: The interpolated track with shape (N, 7). Each row denotes
  49. (frame_id, track_id, x1, y1, x2, y2, score)
  50. """
  51. assert (track[:, 1] == track_id).all(), \
  52. 'The track id should not changed when interpolate a track.'
  53. frame_ids = track[:, 0]
  54. interpolated_track = np.zeros((0, 7))
  55. # perform interpolation for the disconnected frames in the track.
  56. for i in np.where(np.diff(frame_ids) > 1)[0]:
  57. left_frame_id = frame_ids[i]
  58. right_frame_id = frame_ids[i + 1]
  59. num_disconnected_frames = int(right_frame_id - left_frame_id)
  60. if 1 < num_disconnected_frames < max_num_frames:
  61. left_bbox = track[i, 2:6]
  62. right_bbox = track[i + 1, 2:6]
  63. # perform interpolation for two adjacent tracklets.
  64. for j in range(1, num_disconnected_frames):
  65. cur_bbox = j / (num_disconnected_frames) * (
  66. right_bbox - left_bbox) + left_bbox
  67. cur_result = np.ones((7, ))
  68. cur_result[0] = j + left_frame_id
  69. cur_result[1] = track_id
  70. cur_result[2:6] = cur_bbox
  71. interpolated_track = np.concatenate(
  72. (interpolated_track, cur_result[None]), axis=0)
  73. interpolated_track = np.concatenate((track, interpolated_track),
  74. axis=0)
  75. return interpolated_track
  76. def gaussian_smoothed_interpolation(self,
  77. track: np.ndarray,
  78. smooth_tau: int = 10) -> np.ndarray:
  79. """Gaussian-Smoothed Interpolation.
  80. This function is proposed in
  81. "StrongSORT: Make DeepSORT Great Again"
  82. `StrongSORT<https://arxiv.org/abs/2202.13514>`_.
  83. Args:
  84. track (ndarray): With shape (N, 7). Each row denotes
  85. (frame_id, track_id, x1, y1, x2, y2, score).
  86. smooth_tau (int, optional): smoothing parameter in GSI.
  87. Defaults to 10.
  88. Returns:
  89. ndarray: The interpolated tracks with shape (N, 7). Each row
  90. denotes (frame_id, track_id, x1, y1, x2, y2, score)
  91. """
  92. len_scale = np.clip(smooth_tau * np.log(smooth_tau**3 / len(track)),
  93. smooth_tau**-1, smooth_tau**2)
  94. gpr = GPR(RBF(len_scale, 'fixed'))
  95. t = track[:, 0].reshape(-1, 1)
  96. x1 = track[:, 2].reshape(-1, 1)
  97. y1 = track[:, 3].reshape(-1, 1)
  98. x2 = track[:, 4].reshape(-1, 1)
  99. y2 = track[:, 5].reshape(-1, 1)
  100. gpr.fit(t, x1)
  101. x1_gpr = gpr.predict(t)
  102. gpr.fit(t, y1)
  103. y1_gpr = gpr.predict(t)
  104. gpr.fit(t, x2)
  105. x2_gpr = gpr.predict(t)
  106. gpr.fit(t, y2)
  107. y2_gpr = gpr.predict(t)
  108. gsi_track = [[
  109. t[i, 0], track[i, 1], x1_gpr[i], y1_gpr[i], x2_gpr[i], y2_gpr[i],
  110. track[i, 6]
  111. ] for i in range(len(t))]
  112. return np.array(gsi_track)
  113. def forward(self, pred_tracks: np.ndarray) -> np.ndarray:
  114. """Forward function.
  115. pred_tracks (ndarray): With shape (N, 7). Each row denotes
  116. (frame_id, track_id, x1, y1, x2, y2, score).
  117. Returns:
  118. ndarray: The interpolated tracks with shape (N, 7). Each row
  119. denotes (frame_id, track_id, x1, y1, x2, y2, score).
  120. """
  121. max_track_id = int(np.max(pred_tracks[:, 1]))
  122. min_track_id = int(np.min(pred_tracks[:, 1]))
  123. # perform interpolation for each track
  124. interpolated_tracks = []
  125. for track_id in range(min_track_id, max_track_id + 1):
  126. inds = pred_tracks[:, 1] == track_id
  127. track = pred_tracks[inds]
  128. num_frames = len(track)
  129. if num_frames <= 2:
  130. continue
  131. if num_frames > self.min_num_frames:
  132. interpolated_track = self._interpolate_track(
  133. track, track_id, self.max_num_frames)
  134. else:
  135. interpolated_track = track
  136. if self.use_gsi:
  137. interpolated_track = self.gaussian_smoothed_interpolation(
  138. interpolated_track, self.smooth_tau)
  139. interpolated_tracks.append(interpolated_track)
  140. interpolated_tracks = np.concatenate(interpolated_tracks)
  141. return interpolated_tracks[interpolated_tracks[:, 0].argsort()]