# linear_reid_head.py (7.5 KB)
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import warnings
  3. from typing import List, Optional, Tuple, Union
  4. import torch
  5. import torch.nn as nn
  6. try:
  7. import mmpretrain
  8. from mmpretrain.evaluation.metrics import Accuracy
  9. except ImportError:
  10. mmpretrain = None
  11. from mmengine.model import BaseModule
  12. from mmdet.registry import MODELS
  13. from mmdet.structures import ReIDDataSample
  14. from .fc_module import FcModule
  15. @MODELS.register_module()
  16. class LinearReIDHead(BaseModule):
  17. """Linear head for re-identification.
  18. Args:
  19. num_fcs (int): Number of fcs.
  20. in_channels (int): Number of channels in the input.
  21. fc_channels (int): Number of channels in the fcs.
  22. out_channels (int): Number of channels in the output.
  23. norm_cfg (dict, optional): Configuration of normlization method
  24. after fc. Defaults to None.
  25. act_cfg (dict, optional): Configuration of activation method after fc.
  26. Defaults to None.
  27. num_classes (int, optional): Number of the identities. Default to None.
  28. loss_cls (dict, optional): Cross entropy loss to train the ReID module.
  29. Defaults to None.
  30. loss_triplet (dict, optional): Triplet loss to train the ReID module.
  31. Defaults to None.
  32. topk (int | Tuple[int]): Top-k accuracy. Defaults to ``(1, )``.
  33. init_cfg (dict or list[dict], optional): Initialization config dict.
  34. Defaults to dict(type='Normal',layer='Linear', mean=0, std=0.01,
  35. bias=0).
  36. """
  37. def __init__(self,
  38. num_fcs: int,
  39. in_channels: int,
  40. fc_channels: int,
  41. out_channels: int,
  42. norm_cfg: Optional[dict] = None,
  43. act_cfg: Optional[dict] = None,
  44. num_classes: Optional[int] = None,
  45. loss_cls: Optional[dict] = None,
  46. loss_triplet: Optional[dict] = None,
  47. topk: Union[int, Tuple[int]] = (1, ),
  48. init_cfg: Union[dict, List[dict]] = dict(
  49. type='Normal', layer='Linear', mean=0, std=0.01, bias=0)):
  50. if mmpretrain is None:
  51. raise RuntimeError('Please run "pip install openmim" and '
  52. 'run "mim install mmpretrain" to '
  53. 'install mmpretrain first.')
  54. super(LinearReIDHead, self).__init__(init_cfg=init_cfg)
  55. assert isinstance(topk, (int, tuple))
  56. if isinstance(topk, int):
  57. topk = (topk, )
  58. for _topk in topk:
  59. assert _topk > 0, 'Top-k should be larger than 0'
  60. self.topk = topk
  61. if loss_cls is None:
  62. if isinstance(num_classes, int):
  63. warnings.warn('Since cross entropy is not set, '
  64. 'the num_classes will be ignored.')
  65. if loss_triplet is None:
  66. raise ValueError('Please choose at least one loss in '
  67. 'triplet loss and cross entropy loss.')
  68. elif not isinstance(num_classes, int):
  69. raise TypeError('The num_classes must be a current number, '
  70. 'if there is cross entropy loss.')
  71. self.loss_cls = MODELS.build(loss_cls) if loss_cls else None
  72. self.loss_triplet = MODELS.build(loss_triplet) \
  73. if loss_triplet else None
  74. self.num_fcs = num_fcs
  75. self.in_channels = in_channels
  76. self.fc_channels = fc_channels
  77. self.out_channels = out_channels
  78. self.norm_cfg = norm_cfg
  79. self.act_cfg = act_cfg
  80. self.num_classes = num_classes
  81. self._init_layers()
  82. def _init_layers(self):
  83. """Initialize fc layers."""
  84. self.fcs = nn.ModuleList()
  85. for i in range(self.num_fcs):
  86. in_channels = self.in_channels if i == 0 else self.fc_channels
  87. self.fcs.append(
  88. FcModule(in_channels, self.fc_channels, self.norm_cfg,
  89. self.act_cfg))
  90. in_channels = self.in_channels if self.num_fcs == 0 else \
  91. self.fc_channels
  92. self.fc_out = nn.Linear(in_channels, self.out_channels)
  93. if self.loss_cls:
  94. self.bn = nn.BatchNorm1d(self.out_channels)
  95. self.classifier = nn.Linear(self.out_channels, self.num_classes)
  96. def forward(self, feats: Tuple[torch.Tensor]) -> torch.Tensor:
  97. """The forward process."""
  98. # Multiple stage inputs are acceptable
  99. # but only the last stage will be used.
  100. feats = feats[-1]
  101. for m in self.fcs:
  102. feats = m(feats)
  103. feats = self.fc_out(feats)
  104. return feats
  105. def loss(self, feats: Tuple[torch.Tensor],
  106. data_samples: List[ReIDDataSample]) -> dict:
  107. """Calculate losses.
  108. Args:
  109. feats (tuple[Tensor]): The features extracted from the backbone.
  110. data_samples (List[ReIDDataSample]): The annotation data of
  111. every samples.
  112. Returns:
  113. dict: a dictionary of loss components
  114. """
  115. # The part can be traced by torch.fx
  116. feats = self(feats)
  117. # The part can not be traced by torch.fx
  118. losses = self.loss_by_feat(feats, data_samples)
  119. return losses
  120. def loss_by_feat(self, feats: torch.Tensor,
  121. data_samples: List[ReIDDataSample]) -> dict:
  122. """Unpack data samples and compute loss."""
  123. losses = dict()
  124. gt_label = torch.cat([i.gt_label.label for i in data_samples])
  125. gt_label = gt_label.to(feats.device)
  126. if self.loss_triplet:
  127. losses['triplet_loss'] = self.loss_triplet(feats, gt_label)
  128. if self.loss_cls:
  129. feats_bn = self.bn(feats)
  130. cls_score = self.classifier(feats_bn)
  131. losses['ce_loss'] = self.loss_cls(cls_score, gt_label)
  132. acc = Accuracy.calculate(cls_score, gt_label, topk=self.topk)
  133. losses.update(
  134. {f'accuracy_top-{k}': a
  135. for k, a in zip(self.topk, acc)})
  136. return losses
  137. def predict(
  138. self,
  139. feats: Tuple[torch.Tensor],
  140. data_samples: List[ReIDDataSample] = None) -> List[ReIDDataSample]:
  141. """Inference without augmentation.
  142. Args:
  143. feats (Tuple[Tensor]): The features extracted from the backbone.
  144. Multiple stage inputs are acceptable but only the last stage
  145. will be used.
  146. data_samples (List[ReIDDataSample], optional): The annotation
  147. data of every samples. If not None, set ``pred_label`` of
  148. the input data samples. Defaults to None.
  149. Returns:
  150. List[ReIDDataSample]: A list of data samples which contains the
  151. predicted results.
  152. """
  153. # The part can be traced by torch.fx
  154. feats = self(feats)
  155. # The part can not be traced by torch.fx
  156. data_samples = self.predict_by_feat(feats, data_samples)
  157. return data_samples
  158. def predict_by_feat(
  159. self,
  160. feats: torch.Tensor,
  161. data_samples: List[ReIDDataSample] = None) -> List[ReIDDataSample]:
  162. """Add prediction features to data samples."""
  163. if data_samples is not None:
  164. for data_sample, feat in zip(data_samples, feats):
  165. data_sample.pred_feature = feat
  166. else:
  167. data_samples = []
  168. for feat in feats:
  169. data_sample = ReIDDataSample()
  170. data_sample.pred_feature = feat
  171. data_samples.append(data_sample)
  172. return data_samples