masktrack_rcnn.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181
  1. # Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import Optional

import torch
from torch import Tensor

from mmdet.models.mot import BaseMOTModel
from mmdet.registry import MODELS
from mmdet.structures import TrackSampleList
from mmdet.utils import OptConfigType, OptMultiConfig


  9. @MODELS.register_module()
  10. class MaskTrackRCNN(BaseMOTModel):
  11. """Video Instance Segmentation.
  12. This video instance segmentor is the implementation of`MaskTrack R-CNN
  13. <https://arxiv.org/abs/1905.04804>`_.
  14. Args:
  15. detector (dict): Configuration of detector. Defaults to None.
  16. track_head (dict): Configuration of track head. Defaults to None.
  17. tracker (dict): Configuration of tracker. Defaults to None.
  18. data_preprocessor (dict or ConfigDict, optional): The pre-process
  19. config of :class:`TrackDataPreprocessor`. it usually includes,
  20. ``pad_size_divisor``, ``pad_value``, ``mean`` and ``std``.
  21. init_cfg (dict or list[dict]): Configuration of initialization.
  22. Defaults to None.
  23. """
  24. def __init__(self,
  25. detector: Optional[dict] = None,
  26. track_head: Optional[dict] = None,
  27. tracker: Optional[dict] = None,
  28. data_preprocessor: OptConfigType = None,
  29. init_cfg: OptMultiConfig = None):
  30. super().__init__(data_preprocessor, init_cfg)
  31. if detector is not None:
  32. self.detector = MODELS.build(detector)
  33. assert hasattr(self.detector, 'roi_head'), \
  34. 'MaskTrack R-CNN only supports two stage detectors.'
  35. if track_head is not None:
  36. self.track_head = MODELS.build(track_head)
  37. if tracker is not None:
  38. self.tracker = MODELS.build(tracker)
  39. def loss(self, inputs: Tensor, data_samples: TrackSampleList,
  40. **kwargs) -> dict:
  41. """Calculate losses from a batch of inputs and data samples.
  42. Args:
  43. inputs (Dict[str, Tensor]): of shape (N, T, C, H, W) encoding
  44. input images. Typically these should be mean centered and std
  45. scaled. The N denotes batch size. The T denotes the number of
  46. frames.
  47. data_samples (list[:obj:`TrackDataSample`]): The batch
  48. data samples. It usually includes information such
  49. as `gt_instance`.
  50. Returns:
  51. dict: A dictionary of loss components.
  52. """
  53. assert inputs.dim() == 5, 'The img must be 5D Tensor (N, T, C, H, W).'
  54. assert inputs.size(1) == 2, \
  55. 'MaskTrackRCNN can only have 1 key frame and 1 reference frame.'
  56. # split the data_samples into two aspects: key frames and reference
  57. # frames
  58. ref_data_samples, key_data_samples = [], []
  59. key_frame_inds, ref_frame_inds = [], []
  60. # set cat_id of gt_labels to 0 in RPN
  61. for track_data_sample in data_samples:
  62. key_data_sample = track_data_sample.get_key_frames()[0]
  63. key_data_samples.append(key_data_sample)
  64. ref_data_sample = track_data_sample.get_ref_frames()[0]
  65. ref_data_samples.append(ref_data_sample)
  66. key_frame_inds.append(track_data_sample.key_frames_inds[0])
  67. ref_frame_inds.append(track_data_sample.ref_frames_inds[0])
  68. key_frame_inds = torch.tensor(key_frame_inds, dtype=torch.int64)
  69. ref_frame_inds = torch.tensor(ref_frame_inds, dtype=torch.int64)
  70. batch_inds = torch.arange(len(inputs))
  71. key_imgs = inputs[batch_inds, key_frame_inds].contiguous()
  72. ref_imgs = inputs[batch_inds, ref_frame_inds].contiguous()
  73. x = self.detector.extract_feat(key_imgs)
  74. ref_x = self.detector.extract_feat(ref_imgs)
  75. losses = dict()
  76. # RPN forward and loss
  77. if self.detector.with_rpn:
  78. proposal_cfg = self.detector.train_cfg.get(
  79. 'rpn_proposal', self.detector.test_cfg.rpn)
  80. rpn_losses, rpn_results_list = self.detector.rpn_head. \
  81. loss_and_predict(x,
  82. key_data_samples,
  83. proposal_cfg=proposal_cfg,
  84. **kwargs)
  85. # avoid get same name with roi_head loss
  86. keys = rpn_losses.keys()
  87. for key in keys:
  88. if 'loss' in key and 'rpn' not in key:
  89. rpn_losses[f'rpn_{key}'] = rpn_losses.pop(key)
  90. losses.update(rpn_losses)
  91. else:
  92. # TODO: Not support currently, should have a check at Fast R-CNN
  93. assert key_data_samples[0].get('proposals', None) is not None
  94. # use pre-defined proposals in InstanceData for the second stage
  95. # to extract ROI features.
  96. rpn_results_list = [
  97. key_data_sample.proposals
  98. for key_data_sample in key_data_samples
  99. ]
  100. losses_detect = self.detector.roi_head.loss(x, rpn_results_list,
  101. key_data_samples, **kwargs)
  102. losses.update(losses_detect)
  103. losses_track = self.track_head.loss(x, ref_x, rpn_results_list,
  104. data_samples, **kwargs)
  105. losses.update(losses_track)
  106. return losses
  107. def predict(self,
  108. inputs: Tensor,
  109. data_samples: TrackSampleList,
  110. rescale: bool = True,
  111. **kwargs) -> TrackSampleList:
  112. """Test without augmentation.
  113. Args:
  114. inputs (Tensor): of shape (N, T, C, H, W) encoding
  115. input images. The N denotes batch size.
  116. The T denotes the number of frames in a video.
  117. data_samples (list[:obj:`TrackDataSample`]): The batch
  118. data samples. It usually includes information such
  119. as `video_data_samples`.
  120. rescale (bool, Optional): If False, then returned bboxes and masks
  121. will fit the scale of img, otherwise, returned bboxes and masks
  122. will fit the scale of original image shape. Defaults to True.
  123. Returns:
  124. TrackSampleList: Tracking results of the inputs.
  125. """
  126. assert inputs.dim() == 5, 'The img must be 5D Tensor (N, T, C, H, W).'
  127. assert len(data_samples) == 1, \
  128. 'MaskTrackRCNN only support 1 batch size per gpu for now.'
  129. track_data_sample = data_samples[0]
  130. video_len = len(track_data_sample)
  131. if track_data_sample[0].frame_id == 0:
  132. self.tracker.reset()
  133. for frame_id in range(video_len):
  134. img_data_sample = track_data_sample[frame_id]
  135. single_img = inputs[:, frame_id].contiguous()
  136. x = self.detector.extract_feat(single_img)
  137. rpn_results_list = self.detector.rpn_head.predict(
  138. x, [img_data_sample])
  139. # det_results List[InstanceData]
  140. det_results = self.detector.roi_head.predict(
  141. x, rpn_results_list, [img_data_sample], rescale=rescale)
  142. assert len(det_results) == 1, 'Batch inference is not supported.'
  143. assert 'masks' in det_results[0], 'There are no mask results.'
  144. img_data_sample.pred_instances = det_results[0]
  145. frame_pred_track_instances = self.tracker.track(
  146. model=self, feats=x, data_sample=img_data_sample, **kwargs)
  147. img_data_sample.pred_track_instances = frame_pred_track_instances
  148. return [track_data_sample]