base_reid.py 2.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import List, Optional
  3. import torch
  4. try:
  5. import mmpretrain
  6. from mmpretrain.models.classifiers import ImageClassifier
  7. except ImportError:
  8. mmpretrain = None
  9. ImageClassifier = object
  10. from mmdet.registry import MODELS
  11. from mmdet.structures import ReIDDataSample
  @MODELS.register_module()
  class BaseReID(ImageClassifier):
      """Base model for re-identification.

      Thin wrapper around mmpretrain's ``ImageClassifier``. In addition to the
      usual 4-D image batch, ``forward`` accepts a 5-D input whose leading
      dimension is 1, which is unwrapped before being handed to the parent.

      If mmpretrain cannot be imported, the base class falls back to
      ``object`` and constructing this model raises ``RuntimeError``.
      """

      def __init__(self, *args, **kwargs):
          """Build the model; all arguments are forwarded to the parent.

          Raises:
              RuntimeError: If mmpretrain is not installed.
          """
          if mmpretrain is None:
              raise RuntimeError('Please run "pip install openmim" and '
                                 'run "mim install mmpretrain" to '
                                 'install mmpretrain first.')
          super().__init__(*args, **kwargs)

      def forward(self,
                  inputs: torch.Tensor,
                  data_samples: Optional[List[ReIDDataSample]] = None,
                  mode: str = 'tensor'):
          """The unified entry for a forward process in both training and test.

          The method should accept three modes: "tensor", "predict" and "loss":

          - "tensor": Forward the whole network and return tensor or tuple of
            tensor without any post-processing, same as a common nn.Module.
          - "predict": Forward and return the predictions, which are fully
            processed to a list of :obj:`ReIDDataSample`.
          - "loss": Forward and return a dict of losses according to the given
            inputs and data samples.

          Note that this method handles neither back propagation nor
          optimizer updating, which are done in :meth:`train_step`.

          Args:
              inputs (torch.Tensor): The input tensor with shape (N, C, H, W),
                  or a 5-D tensor with shape (1, T, C, H, W). For 5-D input
                  the leading dimension must be 1; it is dropped
                  (``inputs[0]``), so the T entries along the second axis are
                  forwarded to the parent as the batch dimension.
              data_samples (List[ReIDDataSample], optional): The annotation
                  data of every sample. It's required if ``mode="loss"``.
                  Passed to the parent unchanged, including when a 5-D input
                  is unwrapped. Defaults to None.
              mode (str): Return what kind of value. Defaults to 'tensor'.

          Returns:
              The return type depends on ``mode``.

              - If ``mode="tensor"``, return a tensor or a tuple of tensor.
              - If ``mode="predict"``, return a list of
                :obj:`ReIDDataSample`.
              - If ``mode="loss"``, return a dict of tensor.

          Raises:
              AssertionError: If ``inputs`` is 5-D and its first dimension is
                  not 1.
          """
          if inputs.dim() == 5:
              # Only a single leading entry is supported for 5-D input; it is
              # unwrapped so the remaining (T, C, H, W) tensor is treated as a
              # batch. The check is an explicit raise rather than ``assert`` so
              # it is not stripped under ``python -O`` (which would silently
              # drop every sample but the first). AssertionError is kept so
              # callers see the same exception type as before.
              if inputs.size(0) != 1:
                  raise AssertionError(
                      'BaseReID.forward expects a 5-D input with a leading '
                      f'dimension of 1, got shape {tuple(inputs.shape)}.')
              inputs = inputs[0]
          return super().forward(inputs, data_samples, mode)