reid_data_sample.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from numbers import Number
  3. from typing import Sequence, Union
  4. import mmengine
  5. import numpy as np
  6. import torch
  7. from mmengine.structures import BaseDataElement, LabelData
  8. def format_label(value: Union[torch.Tensor, np.ndarray, Sequence, int],
  9. num_classes: int = None) -> LabelData:
  10. """Convert label of various python types to :obj:`mmengine.LabelData`.
  11. Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
  12. :class:`Sequence`, :class:`int`.
  13. Args:
  14. value (torch.Tensor | numpy.ndarray | Sequence | int): Label value.
  15. num_classes (int, optional): The number of classes. If not None, set
  16. it to the metainfo. Defaults to None.
  17. Returns:
  18. :obj:`mmengine.LabelData`: The foramtted label data.
  19. """
  20. # Handle single number
  21. if isinstance(value, (torch.Tensor, np.ndarray)) and value.ndim == 0:
  22. value = int(value.item())
  23. if isinstance(value, np.ndarray):
  24. value = torch.from_numpy(value)
  25. elif isinstance(value, Sequence) and not mmengine.utils.is_str(value):
  26. value = torch.tensor(value)
  27. elif isinstance(value, int):
  28. value = torch.LongTensor([value])
  29. elif not isinstance(value, torch.Tensor):
  30. raise TypeError(f'Type {type(value)} is not an available label type.')
  31. metainfo = {}
  32. if num_classes is not None:
  33. metainfo['num_classes'] = num_classes
  34. if value.max() >= num_classes:
  35. raise ValueError(f'The label data ({value}) should not '
  36. f'exceed num_classes ({num_classes}).')
  37. label = LabelData(label=value, metainfo=metainfo)
  38. return label
  39. class ReIDDataSample(BaseDataElement):
  40. """A data structure interface of ReID task.
  41. It's used as interfaces between different components.
  42. Meta field:
  43. img_shape (Tuple): The shape of the corresponding input image.
  44. Used for visualization.
  45. ori_shape (Tuple): The original shape of the corresponding image.
  46. Used for visualization.
  47. num_classes (int): The number of all categories.
  48. Used for label format conversion.
  49. Data field:
  50. gt_label (LabelData): The ground truth label.
  51. pred_label (LabelData): The predicted label.
  52. scores (torch.Tensor): The outputs of model.
  53. """
  54. @property
  55. def gt_label(self):
  56. return self._gt_label
  57. @gt_label.setter
  58. def gt_label(self, value: LabelData):
  59. self.set_field(value, '_gt_label', dtype=LabelData)
  60. @gt_label.deleter
  61. def gt_label(self):
  62. del self._gt_label
  63. def set_gt_label(
  64. self, value: Union[np.ndarray, torch.Tensor, Sequence[Number], Number]
  65. ) -> 'ReIDDataSample':
  66. """Set label of ``gt_label``."""
  67. label = format_label(value, self.get('num_classes'))
  68. if 'gt_label' in self: # setting for the second time
  69. self.gt_label.label = label.label
  70. else: # setting for the first time
  71. self.gt_label = label
  72. return self
  73. def set_gt_score(self, value: torch.Tensor) -> 'ReIDDataSample':
  74. """Set score of ``gt_label``."""
  75. assert isinstance(value, torch.Tensor), \
  76. f'The value should be a torch.Tensor but got {type(value)}.'
  77. assert value.ndim == 1, \
  78. f'The dims of value should be 1, but got {value.ndim}.'
  79. if 'num_classes' in self:
  80. assert value.size(0) == self.num_classes, \
  81. f"The length of value ({value.size(0)}) doesn't "\
  82. f'match the num_classes ({self.num_classes}).'
  83. metainfo = {'num_classes': self.num_classes}
  84. else:
  85. metainfo = {'num_classes': value.size(0)}
  86. if 'gt_label' in self: # setting for the second time
  87. self.gt_label.score = value
  88. else: # setting for the first time
  89. self.gt_label = LabelData(score=value, metainfo=metainfo)
  90. return self
  91. @property
  92. def pred_feature(self):
  93. return self._pred_feature
  94. @pred_feature.setter
  95. def pred_feature(self, value: torch.Tensor):
  96. self.set_field(value, '_pred_feature', dtype=torch.Tensor)
  97. @pred_feature.deleter
  98. def pred_feature(self):
  99. del self._pred_feature