mask2former_vis.py

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Union

from torch import Tensor

from mmdet.models.mot import BaseMOTModel
from mmdet.registry import MODELS
from mmdet.structures import TrackDataSample, TrackSampleList
from mmdet.utils import OptConfigType, OptMultiConfig


@MODELS.register_module()
class Mask2FormerVideo(BaseMOTModel):
    r"""Implementation of `Masked-attention Mask Transformer for Universal
    Image Segmentation <https://arxiv.org/pdf/2112.01527>`_.

    Args:
        backbone (dict): Configuration of the backbone. Defaults to None.
        track_head (dict): Configuration of the track head. Defaults to None.
        data_preprocessor (dict or ConfigDict, optional): The pre-processing
            config of :class:`TrackDataPreprocessor`. It usually includes
            ``pad_size_divisor``, ``pad_value``, ``mean`` and ``std``.
            Defaults to None.
        init_cfg (dict or list[dict]): Configuration of initialization.
            Defaults to None.
    """

    def __init__(self,
                 backbone: Optional[dict] = None,
                 track_head: Optional[dict] = None,
                 data_preprocessor: OptConfigType = None,
                 init_cfg: OptMultiConfig = None):
        # Note: bypasses ``BaseMOTModel.__init__`` and initializes
        # ``BaseModel`` directly.
        super(BaseMOTModel, self).__init__(
            data_preprocessor=data_preprocessor, init_cfg=init_cfg)

        if backbone is not None:
            self.backbone = MODELS.build(backbone)
        if track_head is not None:
            self.track_head = MODELS.build(track_head)

        self.num_classes = self.track_head.num_classes
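
    # Illustrative shape of the two configs above (the field values are
    # assumptions for illustration; real values come from the corresponding
    # mmdet Mask2Former VIS configs):
    #   backbone=dict(type='ResNet', depth=50, ...)
    #   track_head=dict(type='Mask2FormerTrackHead', ...)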

    def _load_from_state_dict(self, state_dict, prefix, local_metadata,
                              strict, missing_keys, unexpected_keys,
                              error_msgs):
        """Overridden so that mmdet pretrained checkpoints can be loaded:
        keys of the image-level ``panoptic_head`` are remapped to
        ``track_head``."""
        for key in list(state_dict):
            if key.startswith('panoptic_head'):
                state_dict[key.replace('panoptic',
                                       'track')] = state_dict.pop(key)

        super()._load_from_state_dict(state_dict, prefix, local_metadata,
                                      strict, missing_keys, unexpected_keys,
                                      error_msgs)
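
    # Example of the remapping above: an image-level Mask2Former checkpoint
    # key such as 'panoptic_head.cls_embed.weight' would be renamed to
    # 'track_head.cls_embed.weight' before the parent loader consumes it
    # (the suffix 'cls_embed.weight' is illustrative, not a guarantee about
    # any particular checkpoint's layout).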

    def loss(self, inputs: Tensor, data_samples: TrackSampleList,
             **kwargs) -> Union[dict, tuple]:
        """
        Args:
            inputs (Tensor): Input images of shape (N, T, C, H, W).
                These should usually be mean centered and std scaled.
            data_samples (list[:obj:`TrackDataSample`]): The batch
                data samples. It usually includes information such as
                ``gt_instances``.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        assert inputs.dim() == 5, \
            'The inputs must be a 5D Tensor of shape (N, T, C, H, W).'
        # shape (N * T, C, H, W)
        img = inputs.flatten(0, 1)

        x = self.backbone(img)
        losses = self.track_head.loss(x, data_samples)

        return losses

    def predict(self,
                inputs: Tensor,
                data_samples: TrackSampleList,
                rescale: bool = True) -> TrackSampleList:
        """Predict results from a batch of inputs and data samples with
        post-processing.

        Args:
            inputs (Tensor): Input images of shape (N, T, C, H, W), where N
                denotes the batch size and T denotes the number of frames
                in a video.
            data_samples (list[:obj:`TrackDataSample`]): The batch
                data samples. It usually includes information such as
                ``video_data_samples``.
            rescale (bool, optional): If False, the returned bboxes and
                masks fit the scale of the input images; otherwise, they
                fit the scale of the original image shape. Defaults to
                True.

        Returns:
            TrackSampleList: Tracking results of the inputs.
        """
        assert inputs.dim() == 5, \
            'The inputs must be a 5D Tensor of shape (N, T, C, H, W).'
        assert len(data_samples) == 1, \
            'Mask2FormerVideo only supports batch size 1 per GPU for now.'

        # [T, C, H, W]
        img = inputs[0]
        track_data_sample = data_samples[0]

        feats = self.backbone(img)
        pred_track_ins_list = self.track_head.predict(feats,
                                                      track_data_sample,
                                                      rescale)

        det_data_samples_list = []
        for idx, pred_track_ins in enumerate(pred_track_ins_list):
            img_data_sample = track_data_sample[idx]
            img_data_sample.pred_track_instances = pred_track_ins
            det_data_samples_list.append(img_data_sample)

        # Repack the per-frame results into a single TrackDataSample for the
        # whole clip.
        results = TrackDataSample()
        results.video_data_samples = det_data_samples_list

        return [results]
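

if __name__ == '__main__':
    # Minimal smoke-test sketch, not part of the original module. The stub
    # backbone and track head below are hypothetical placeholders registered
    # only so that ``MODELS.build`` can resolve them; they are not real
    # mmdet components. The goal is just to illustrate the (N, T, C, H, W)
    # input contract of ``loss``.
    import torch
    from torch import nn

    @MODELS.register_module()
    class StubBackbone(nn.Module):

        def forward(self, x: Tensor) -> tuple:
            # Mimic a backbone that returns a tuple of feature maps.
            return (x, )

    @MODELS.register_module()
    class StubTrackHead(nn.Module):
        num_classes = 2  # read by Mask2FormerVideo.__init__

        def loss(self, feats: tuple, data_samples: list) -> dict:
            # Dummy loss: mean activation of the single feature level.
            return {'loss_stub': feats[0].mean()}

    model = Mask2FormerVideo(
        backbone=dict(type='StubBackbone'),
        track_head=dict(type='StubTrackHead'))
    clip = torch.randn(1, 2, 3, 32, 32)  # (N, T, C, H, W)
    print(model.loss(clip, data_samples=[]))  # -> {'loss_stub': tensor(...)}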