camera_motion_compensation.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import cv2
  3. import numpy as np
  4. import torch
  5. from torch import Tensor
  6. from mmdet.registry import TASK_UTILS
  7. from mmdet.structures.bbox import bbox_cxcyah_to_xyxy, bbox_xyxy_to_cxcyah
  8. @TASK_UTILS.register_module()
  9. class CameraMotionCompensation:
  10. """Camera motion compensation.
  11. Args:
  12. warp_mode (str): Warp mode in opencv.
  13. Defaults to 'cv2.MOTION_EUCLIDEAN'.
  14. num_iters (int): Number of the iterations. Defaults to 50.
  15. stop_eps (float): Terminate threshold. Defaults to 0.001.
  16. """
  17. def __init__(self,
  18. warp_mode: str = 'cv2.MOTION_EUCLIDEAN',
  19. num_iters: int = 50,
  20. stop_eps: float = 0.001):
  21. self.warp_mode = eval(warp_mode)
  22. self.num_iters = num_iters
  23. self.stop_eps = stop_eps
  24. def get_warp_matrix(self, img: np.ndarray, ref_img: np.ndarray) -> Tensor:
  25. """Calculate warping matrix between two images."""
  26. img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
  27. ref_img = cv2.cvtColor(ref_img, cv2.COLOR_BGR2GRAY)
  28. warp_matrix = np.eye(2, 3, dtype=np.float32)
  29. criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT,
  30. self.num_iters, self.stop_eps)
  31. cc, warp_matrix = cv2.findTransformECC(img, ref_img, warp_matrix,
  32. self.warp_mode, criteria, None,
  33. 1)
  34. warp_matrix = torch.from_numpy(warp_matrix)
  35. return warp_matrix
  36. def warp_bboxes(self, bboxes: Tensor, warp_matrix: Tensor) -> Tensor:
  37. """Warp bounding boxes according to the warping matrix."""
  38. tl, br = bboxes[:, :2], bboxes[:, 2:]
  39. tl = torch.cat((tl, torch.ones(tl.shape[0], 1).to(bboxes.device)),
  40. dim=1)
  41. br = torch.cat((br, torch.ones(tl.shape[0], 1).to(bboxes.device)),
  42. dim=1)
  43. trans_tl = torch.mm(warp_matrix, tl.t()).t()
  44. trans_br = torch.mm(warp_matrix, br.t()).t()
  45. trans_bboxes = torch.cat((trans_tl, trans_br), dim=1)
  46. return trans_bboxes.to(bboxes.device)
  47. def warp_means(self, means: np.ndarray, warp_matrix: Tensor) -> np.ndarray:
  48. """Warp track.mean according to the warping matrix."""
  49. cxcyah = torch.from_numpy(means[:, :4]).float()
  50. xyxy = bbox_cxcyah_to_xyxy(cxcyah)
  51. warped_xyxy = self.warp_bboxes(xyxy, warp_matrix)
  52. warped_cxcyah = bbox_xyxy_to_cxcyah(warped_xyxy).numpy()
  53. means[:, :4] = warped_cxcyah
  54. return means
  55. def track(self, img: Tensor, ref_img: Tensor, tracks: dict,
  56. num_samples: int, frame_id: int, metainfo: dict) -> dict:
  57. """Tracking forward."""
  58. img = img.squeeze(0).cpu().numpy().transpose((1, 2, 0))
  59. ref_img = ref_img.squeeze(0).cpu().numpy().transpose((1, 2, 0))
  60. warp_matrix = self.get_warp_matrix(img, ref_img)
  61. # rescale the warp_matrix due to the `resize` in pipeline
  62. scale_factor_h, scale_factor_w = metainfo['scale_factor']
  63. warp_matrix[0, 2] = warp_matrix[0, 2] / scale_factor_w
  64. warp_matrix[1, 2] = warp_matrix[1, 2] / scale_factor_h
  65. bboxes = []
  66. num_bboxes = []
  67. means = []
  68. for k, v in tracks.items():
  69. if int(v['frame_ids'][-1]) < frame_id - 1:
  70. _num = 1
  71. else:
  72. _num = min(num_samples, len(v.bboxes))
  73. num_bboxes.append(_num)
  74. bboxes.extend(v.bboxes[-_num:])
  75. if len(v.mean) > 0:
  76. means.append(v.mean)
  77. bboxes = torch.cat(bboxes, dim=0)
  78. warped_bboxes = self.warp_bboxes(bboxes, warp_matrix.to(bboxes.device))
  79. warped_bboxes = torch.split(warped_bboxes, num_bboxes)
  80. for b, (k, v) in zip(warped_bboxes, tracks.items()):
  81. _num = b.shape[0]
  82. b = torch.split(b, [1] * _num)
  83. tracks[k].bboxes[-_num:] = b
  84. if means:
  85. means = np.asarray(means)
  86. warped_means = self.warp_means(means, warp_matrix)
  87. for m, (k, v) in zip(warped_means, tracks.items()):
  88. tracks[k].mean = m
  89. return tracks