masktrack_rcnn_tracker.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import List
  3. import torch
  4. from mmengine.structures import InstanceData
  5. from torch import Tensor
  6. from mmdet.registry import MODELS
  7. from mmdet.structures import DetDataSample
  8. from mmdet.structures.bbox import bbox_overlaps
  9. from .base_tracker import BaseTracker
  10. @MODELS.register_module()
  11. class MaskTrackRCNNTracker(BaseTracker):
  12. """Tracker for MaskTrack R-CNN.
  13. Args:
  14. match_weights (dict[str : float]): The Weighting factor when computing
  15. the match score. It contains keys as follows:
  16. - det_score (float): The coefficient of `det_score` when computing
  17. match score.
  18. - iou (float): The coefficient of `ious` when computing match
  19. score.
  20. - det_label (float): The coefficient of `label_deltas` when
  21. computing match score.
  22. """
  23. def __init__(self,
  24. match_weights: dict = dict(
  25. det_score=1.0, iou=2.0, det_label=10.0),
  26. **kwargs):
  27. super().__init__(**kwargs)
  28. self.match_weights = match_weights
  29. def get_match_score(self, bboxes: Tensor, labels: Tensor, scores: Tensor,
  30. prev_bboxes: Tensor, prev_labels: Tensor,
  31. similarity_logits: Tensor) -> Tensor:
  32. """Get the match score.
  33. Args:
  34. bboxes (torch.Tensor): of shape (num_current_bboxes, 4) in
  35. [tl_x, tl_y, br_x, br_y] format. Denoting the detection
  36. bboxes of current frame.
  37. labels (torch.Tensor): of shape (num_current_bboxes, )
  38. scores (torch.Tensor): of shape (num_current_bboxes, )
  39. prev_bboxes (torch.Tensor): of shape (num_previous_bboxes, 4) in
  40. [tl_x, tl_y, br_x, br_y] format. Denoting the detection bboxes
  41. of previous frame.
  42. prev_labels (torch.Tensor): of shape (num_previous_bboxes, )
  43. similarity_logits (torch.Tensor): of shape (num_current_bboxes,
  44. num_previous_bboxes + 1). Denoting the similarity logits from
  45. track head.
  46. Returns:
  47. torch.Tensor: The matching score of shape (num_current_bboxes,
  48. num_previous_bboxes + 1)
  49. """
  50. similarity_scores = similarity_logits.softmax(dim=1)
  51. ious = bbox_overlaps(bboxes, prev_bboxes)
  52. iou_dummy = ious.new_zeros(ious.shape[0], 1)
  53. ious = torch.cat((iou_dummy, ious), dim=1)
  54. label_deltas = (labels.view(-1, 1) == prev_labels).float()
  55. label_deltas_dummy = label_deltas.new_ones(label_deltas.shape[0], 1)
  56. label_deltas = torch.cat((label_deltas_dummy, label_deltas), dim=1)
  57. match_score = similarity_scores.log()
  58. match_score += self.match_weights['det_score'] * \
  59. scores.view(-1, 1).log()
  60. match_score += self.match_weights['iou'] * ious
  61. match_score += self.match_weights['det_label'] * label_deltas
  62. return match_score
  63. def assign_ids(self, match_scores: Tensor):
  64. num_prev_bboxes = match_scores.shape[1] - 1
  65. _, match_ids = match_scores.max(dim=1)
  66. ids = match_ids.new_zeros(match_ids.shape[0]) - 1
  67. best_match_scores = match_scores.new_zeros(num_prev_bboxes) - 1e6
  68. for idx, match_id in enumerate(match_ids):
  69. if match_id == 0:
  70. ids[idx] = self.num_tracks
  71. self.num_tracks += 1
  72. else:
  73. match_score = match_scores[idx, match_id]
  74. # TODO: fix the bug where multiple candidate might match
  75. # with the same previous object.
  76. if match_score > best_match_scores[match_id - 1]:
  77. ids[idx] = self.ids[match_id - 1]
  78. best_match_scores[match_id - 1] = match_score
  79. return ids, best_match_scores
  80. def track(self,
  81. model: torch.nn.Module,
  82. feats: List[torch.Tensor],
  83. data_sample: DetDataSample,
  84. rescale=True,
  85. **kwargs) -> InstanceData:
  86. """Tracking forward function.
  87. Args:
  88. model (nn.Module): VIS model.
  89. img (Tensor): of shape (T, C, H, W) encoding input image.
  90. Typically these should be mean centered and std scaled.
  91. The T denotes the number of key images and usually is 1 in
  92. MaskTrackRCNN method.
  93. feats (list[Tensor]): Multi level feature maps of `img`.
  94. data_sample (:obj:`TrackDataSample`): The data sample.
  95. It includes information such as `pred_det_instances`.
  96. rescale (bool, optional): If True, the bounding boxes should be
  97. rescaled to fit the original scale of the image. Defaults to
  98. True.
  99. Returns:
  100. :obj:`InstanceData`: Tracking results of the input images.
  101. Each InstanceData usually contains ``bboxes``, ``labels``,
  102. ``scores`` and ``instances_id``.
  103. """
  104. metainfo = data_sample.metainfo
  105. bboxes = data_sample.pred_instances.bboxes
  106. masks = data_sample.pred_instances.masks
  107. labels = data_sample.pred_instances.labels
  108. scores = data_sample.pred_instances.scores
  109. frame_id = metainfo.get('frame_id', -1)
  110. # create pred_track_instances
  111. pred_track_instances = InstanceData()
  112. if bboxes.shape[0] == 0:
  113. ids = torch.zeros_like(labels)
  114. pred_track_instances = data_sample.pred_instances.clone()
  115. pred_track_instances.instances_id = ids
  116. return pred_track_instances
  117. rescaled_bboxes = bboxes.clone()
  118. if rescale:
  119. scale_factor = rescaled_bboxes.new_tensor(
  120. metainfo['scale_factor']).repeat((1, 2))
  121. rescaled_bboxes = rescaled_bboxes * scale_factor
  122. roi_feats, _ = model.track_head.extract_roi_feats(
  123. feats, [rescaled_bboxes])
  124. if self.empty:
  125. num_new_tracks = bboxes.size(0)
  126. ids = torch.arange(
  127. self.num_tracks,
  128. self.num_tracks + num_new_tracks,
  129. dtype=torch.long)
  130. self.num_tracks += num_new_tracks
  131. else:
  132. prev_bboxes = self.get('bboxes')
  133. prev_labels = self.get('labels')
  134. prev_roi_feats = self.get('roi_feats')
  135. similarity_logits = model.track_head.predict(
  136. roi_feats, prev_roi_feats)
  137. match_scores = self.get_match_score(bboxes, labels, scores,
  138. prev_bboxes, prev_labels,
  139. similarity_logits)
  140. ids, _ = self.assign_ids(match_scores)
  141. valid_inds = ids > -1
  142. ids = ids[valid_inds]
  143. bboxes = bboxes[valid_inds]
  144. labels = labels[valid_inds]
  145. scores = scores[valid_inds]
  146. masks = masks[valid_inds]
  147. roi_feats = roi_feats[valid_inds]
  148. self.update(
  149. ids=ids,
  150. bboxes=bboxes,
  151. labels=labels,
  152. scores=scores,
  153. masks=masks,
  154. roi_feats=roi_feats,
  155. frame_ids=frame_id)
  156. # update pred_track_instances
  157. pred_track_instances.bboxes = bboxes
  158. pred_track_instances.masks = masks
  159. pred_track_instances.labels = labels
  160. pred_track_instances.scores = scores
  161. pred_track_instances.instances_id = ids
  162. return pred_track_instances