123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202 |
- # Copyright (c) OpenMMLab. All rights reserved.
- import warnings
- from typing import List, Optional, Tuple, Union
- import torch
- import torch.nn as nn
- try:
- import mmpretrain
- from mmpretrain.evaluation.metrics import Accuracy
- except ImportError:
- mmpretrain = None
- from mmengine.model import BaseModule
- from mmdet.registry import MODELS
- from mmdet.structures import ReIDDataSample
- from .fc_module import FcModule
@MODELS.register_module()
class LinearReIDHead(BaseModule):
    """Linear head for re-identification.

    The head takes the last-stage feature, passes it through ``num_fcs``
    :class:`FcModule` layers and one final ``nn.Linear`` projection, and
    uses the projected feature as the ReID embedding. When a cross entropy
    loss is configured, a ``BatchNorm1d`` + ``Linear`` classifier is built
    on top of the embedding; it is only used in :meth:`loss_by_feat`.

    Args:
        num_fcs (int): Number of fcs.
        in_channels (int): Number of channels in the input.
        fc_channels (int): Number of channels in the fcs.
        out_channels (int): Number of channels in the output.
        norm_cfg (dict, optional): Configuration of normalization method
            after fc. Defaults to None.
        act_cfg (dict, optional): Configuration of activation method after fc.
            Defaults to None.
        num_classes (int, optional): Number of the identities. Must be an
            int when ``loss_cls`` is set; ignored (with a warning)
            otherwise. Defaults to None.
        loss_cls (dict, optional): Cross entropy loss to train the ReID
            module. Defaults to None.
        loss_triplet (dict, optional): Triplet loss to train the ReID module.
            Defaults to None.
        topk (int | Tuple[int]): Top-k accuracy. Defaults to ``(1, )``.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Defaults to dict(type='Normal',layer='Linear', mean=0, std=0.01,
            bias=0).

    Raises:
        RuntimeError: If ``mmpretrain`` is not installed.
        ValueError: If neither ``loss_cls`` nor ``loss_triplet`` is set.
        TypeError: If ``loss_cls`` is set but ``num_classes`` is not an int.
    """

    def __init__(self,
                 num_fcs: int,
                 in_channels: int,
                 fc_channels: int,
                 out_channels: int,
                 norm_cfg: Optional[dict] = None,
                 act_cfg: Optional[dict] = None,
                 num_classes: Optional[int] = None,
                 loss_cls: Optional[dict] = None,
                 loss_triplet: Optional[dict] = None,
                 topk: Union[int, Tuple[int]] = (1, ),
                 init_cfg: Union[dict, List[dict]] = dict(
                     type='Normal', layer='Linear', mean=0, std=0.01, bias=0)):
        if mmpretrain is None:
            raise RuntimeError('Please run "pip install openmim" and '
                               'run "mim install mmpretrain" to '
                               'install mmpretrain first.')
        super().__init__(init_cfg=init_cfg)

        # Normalize ``topk`` to a tuple of positive values.
        assert isinstance(topk, (int, tuple))
        if isinstance(topk, int):
            topk = (topk, )
        for _topk in topk:
            assert _topk > 0, 'Top-k should be larger than 0'
        self.topk = topk

        # A single "is this loss configured?" predicate is used for both
        # validation and construction. Checking ``is None`` in one place and
        # truthiness in the other would let an empty config (``{}``) pass
        # validation and silently produce a head without any loss.
        use_cls = bool(loss_cls)
        use_triplet = bool(loss_triplet)

        if not use_cls:
            if isinstance(num_classes, int):
                warnings.warn('Since cross entropy is not set, '
                              'the num_classes will be ignored.')
            if not use_triplet:
                raise ValueError('Please choose at least one loss in '
                                 'triplet loss and cross entropy loss.')
        elif not isinstance(num_classes, int):
            raise TypeError('The num_classes must be an integer, '
                            'if there is cross entropy loss.')

        self.loss_cls = MODELS.build(loss_cls) if use_cls else None
        self.loss_triplet = MODELS.build(loss_triplet) \
            if use_triplet else None

        self.num_fcs = num_fcs
        self.in_channels = in_channels
        self.fc_channels = fc_channels
        self.out_channels = out_channels
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.num_classes = num_classes

        self._init_layers()

    def _init_layers(self):
        """Initialize fc layers.

        Builds ``self.fcs`` (``num_fcs`` FcModules), the output projection
        ``self.fc_out`` and, only when a cross entropy loss is configured,
        ``self.bn`` and ``self.classifier``.
        """
        self.fcs = nn.ModuleList()
        for i in range(self.num_fcs):
            # Only the first fc consumes the raw input width.
            in_channels = self.in_channels if i == 0 else self.fc_channels
            self.fcs.append(
                FcModule(in_channels, self.fc_channels, self.norm_cfg,
                         self.act_cfg))

        # With no fc layers the projection is applied to the input directly.
        in_channels = self.in_channels if self.num_fcs == 0 else \
            self.fc_channels
        self.fc_out = nn.Linear(in_channels, self.out_channels)

        # Compare against None explicitly instead of relying on the
        # truthiness of an ``nn.Module``.
        if self.loss_cls is not None:
            self.bn = nn.BatchNorm1d(self.out_channels)
            self.classifier = nn.Linear(self.out_channels, self.num_classes)

    def forward(self, feats: Tuple[torch.Tensor]) -> torch.Tensor:
        """The forward process.

        Args:
            feats (tuple[Tensor]): Multi-stage features; only the last
                element is used.

        Returns:
            Tensor: Output of ``fc_out`` applied after all fc layers.
        """
        # Multiple stage inputs are acceptable
        # but only the last stage will be used.
        feats = feats[-1]

        for m in self.fcs:
            feats = m(feats)

        feats = self.fc_out(feats)

        return feats

    def loss(self, feats: Tuple[torch.Tensor],
             data_samples: List[ReIDDataSample]) -> dict:
        """Calculate losses.

        Args:
            feats (tuple[Tensor]): The features extracted from the backbone.
            data_samples (List[ReIDDataSample]): The annotation data of
                every samples.

        Returns:
            dict: a dictionary of loss components
        """
        # The part can be traced by torch.fx
        feats = self(feats)

        # The part can not be traced by torch.fx
        losses = self.loss_by_feat(feats, data_samples)
        return losses

    def loss_by_feat(self, feats: torch.Tensor,
                     data_samples: List[ReIDDataSample]) -> dict:
        """Unpack data samples and compute loss.

        Args:
            feats (Tensor): Output of :meth:`forward`.
            data_samples (List[ReIDDataSample]): Samples whose
                ``gt_label.label`` tensors are concatenated to form the
                training target.

        Returns:
            dict: May contain ``triplet_loss`` (when the triplet loss is
            configured), and ``ce_loss`` plus one ``accuracy_top-{k}`` entry
            per value in ``self.topk`` (when the cross entropy loss is
            configured).
        """
        losses = dict()
        gt_label = torch.cat([i.gt_label.label for i in data_samples])
        # Labels come from the data samples and may live on another device.
        gt_label = gt_label.to(feats.device)

        if self.loss_triplet is not None:
            losses['triplet_loss'] = self.loss_triplet(feats, gt_label)

        if self.loss_cls is not None:
            # The classifier works on the batch-normalized embedding, while
            # the triplet loss above uses the raw embedding.
            feats_bn = self.bn(feats)
            cls_score = self.classifier(feats_bn)
            losses['ce_loss'] = self.loss_cls(cls_score, gt_label)
            acc = Accuracy.calculate(cls_score, gt_label, topk=self.topk)
            losses.update(
                {f'accuracy_top-{k}': a
                 for k, a in zip(self.topk, acc)})

        return losses

    def predict(
        self,
        feats: Tuple[torch.Tensor],
        data_samples: Optional[List[ReIDDataSample]] = None
    ) -> List[ReIDDataSample]:
        """Inference without augmentation.

        Args:
            feats (Tuple[Tensor]): The features extracted from the backbone.
                Multiple stage inputs are acceptable but only the last stage
                will be used.
            data_samples (List[ReIDDataSample], optional): The annotation
                data of every samples. If not None, set ``pred_feature`` of
                the input data samples. Defaults to None.

        Returns:
            List[ReIDDataSample]: A list of data samples which contains the
            predicted results.
        """
        # The part can be traced by torch.fx
        feats = self(feats)

        # The part can not be traced by torch.fx
        data_samples = self.predict_by_feat(feats, data_samples)

        return data_samples

    def predict_by_feat(
        self,
        feats: torch.Tensor,
        data_samples: Optional[List[ReIDDataSample]] = None
    ) -> List[ReIDDataSample]:
        """Add prediction features to data samples.

        Args:
            feats (Tensor): Output of :meth:`forward`; iterated to obtain
                one feature per sample.
            data_samples (List[ReIDDataSample], optional): Samples to
                update in place. When None, a new ``ReIDDataSample`` is
                created for every feature. Defaults to None.

        Returns:
            List[ReIDDataSample]: The given (or newly created) samples with
            ``pred_feature`` set.
        """
        if data_samples is not None:
            # ``zip`` stops at the shorter input, so samples beyond the
            # number of features are left untouched.
            for data_sample, feat in zip(data_samples, feats):
                data_sample.pred_feature = feat
        else:
            data_samples = []
            for feat in feats:
                data_sample = ReIDDataSample()
                data_sample.pred_feature = feat
                data_samples.append(data_sample)

        return data_samples
|