1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465 |
- # Copyright (c) OpenMMLab. All rights reserved.
- from typing import List, Optional
- import torch
- try:
- import mmpretrain
- from mmpretrain.models.classifiers import ImageClassifier
- except ImportError:
- mmpretrain = None
- ImageClassifier = object
- from mmdet.registry import MODELS
- from mmdet.structures import ReIDDataSample
@MODELS.register_module()
class BaseReID(ImageClassifier):
    """Base model for re-identification.

    Extends the ``ImageClassifier`` imported from ``mmpretrain`` so that
    :meth:`forward` also accepts a 5-D input whose leading dimension is 1.

    Raises:
        RuntimeError: On construction, if ``mmpretrain`` is not installed.
    """

    def __init__(self, *args, **kwargs):
        # ``mmpretrain`` is an optional dependency: the module-level import
        # falls back to ``None`` when it is missing, so fail early here with
        # installation instructions instead of a confusing error later.
        if mmpretrain is None:
            raise RuntimeError('Please run "pip install openmim" and '
                               'run "mim install mmpretrain" to '
                               'install mmpretrain first.')
        super().__init__(*args, **kwargs)

    def forward(self,
                inputs: torch.Tensor,
                data_samples: Optional[List[ReIDDataSample]] = None,
                mode: str = 'tensor'):
        """The unified entry for a forward process in both training and test.

        The method should accept three modes: "tensor", "predict" and "loss":

        - "tensor": Forward the whole network and return tensor or tuple of
          tensor without any post-processing, same as a common nn.Module.
        - "predict": Forward and return the predictions, which are fully
          processed to a list of :obj:`ReIDDataSample`.
        - "loss": Forward and return a dict of losses according to the given
          inputs and data samples.

        Note that this method handles neither back propagation nor
        optimizer updating, which are done in the :meth:`train_step`.

        Args:
            inputs (torch.Tensor): The input tensor with shape
                (N, C, H, W) or (N, T, C, H, W). A 5-D input must have
                ``N == 1``; its leading dimension is dropped before the
                tensor is passed to the parent class.
            data_samples (List[ReIDDataSample], optional): The annotation
                data of every sample. It's required if ``mode="loss"``.
                Defaults to None.
            mode (str): Return what kind of value. Defaults to 'tensor'.

        Returns:
            The return type depends on ``mode``.

            - If ``mode="tensor"``, return a tensor or a tuple of tensor.
            - If ``mode="predict"``, return a list of
              :obj:`ReIDDataSample`.
            - If ``mode="loss"``, return a dict of tensor.

        Raises:
            AssertionError: If ``inputs`` is 5-D and its first dimension is
                not 1.
        """
        if inputs.dim() == 5:
            # Explicit check instead of a bare ``assert`` so that it is not
            # stripped under ``python -O``; otherwise ``inputs[0]`` would
            # silently discard everything but the first entry.
            # ``AssertionError`` is kept for backward compatibility.
            if inputs.size(0) != 1:
                raise AssertionError(
                    'A 5-D input must have a leading dimension of size 1, '
                    f'but got shape {tuple(inputs.size())}.')
            inputs = inputs[0]
        return super().forward(inputs, data_samples, mode)
|