quasi_dense_embed_head.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import List, Optional, Tuple
  3. import torch
  4. import torch.nn as nn
  5. from mmcv.cnn import ConvModule
  6. from mmengine.model import BaseModule
  7. from torch import Tensor
  8. from torch.nn.modules.utils import _pair
  9. from mmdet.models.task_modules import SamplingResult
  10. from mmdet.registry import MODELS
  11. from ..task_modules.tracking import embed_similarity
  12. @MODELS.register_module()
  13. class QuasiDenseEmbedHead(BaseModule):
  14. """The quasi-dense roi embed head.
  15. Args:
  16. embed_channels (int): The input channel of embed features.
  17. Defaults to 256.
  18. softmax_temp (int): Softmax temperature. Defaults to -1.
  19. loss_track (dict): The loss function for tracking. Defaults to
  20. MultiPosCrossEntropyLoss.
  21. loss_track_aux (dict): The auxiliary loss function for tracking.
  22. Defaults to MarginL2Loss.
  23. init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \
  24. dict]): Initialization config dict.
  25. """
  26. def __init__(self,
  27. num_convs: int = 0,
  28. num_fcs: int = 0,
  29. roi_feat_size: int = 7,
  30. in_channels: int = 256,
  31. conv_out_channels: int = 256,
  32. with_avg_pool: bool = False,
  33. fc_out_channels: int = 1024,
  34. conv_cfg: Optional[dict] = None,
  35. norm_cfg: Optional[dict] = None,
  36. embed_channels: int = 256,
  37. softmax_temp: int = -1,
  38. loss_track: Optional[dict] = None,
  39. loss_track_aux: dict = dict(
  40. type='MarginL2Loss',
  41. sample_ratio=3,
  42. margin=0.3,
  43. loss_weight=1.0,
  44. hard_mining=True),
  45. init_cfg: dict = dict(
  46. type='Xavier',
  47. layer='Linear',
  48. distribution='uniform',
  49. bias=0,
  50. override=dict(
  51. type='Normal',
  52. name='fc_embed',
  53. mean=0,
  54. std=0.01,
  55. bias=0))):
  56. super(QuasiDenseEmbedHead, self).__init__(init_cfg=init_cfg)
  57. self.num_convs = num_convs
  58. self.num_fcs = num_fcs
  59. self.roi_feat_size = _pair(roi_feat_size)
  60. self.roi_feat_area = self.roi_feat_size[0] * self.roi_feat_size[1]
  61. self.in_channels = in_channels
  62. self.conv_out_channels = conv_out_channels
  63. self.with_avg_pool = with_avg_pool
  64. self.fc_out_channels = fc_out_channels
  65. self.conv_cfg = conv_cfg
  66. self.norm_cfg = norm_cfg
  67. if self.with_avg_pool:
  68. self.avg_pool = nn.AvgPool2d(self.roi_feat_size)
  69. # add convs and fcs
  70. self.convs, self.fcs, self.last_layer_dim = self._add_conv_fc_branch(
  71. self.num_convs, self.num_fcs, self.in_channels)
  72. self.relu = nn.ReLU(inplace=True)
  73. if loss_track is None:
  74. loss_track = dict(
  75. type='MultiPosCrossEntropyLoss', loss_weight=0.25)
  76. self.fc_embed = nn.Linear(self.last_layer_dim, embed_channels)
  77. self.softmax_temp = softmax_temp
  78. self.loss_track = MODELS.build(loss_track)
  79. if loss_track_aux is not None:
  80. self.loss_track_aux = MODELS.build(loss_track_aux)
  81. else:
  82. self.loss_track_aux = None
  83. def _add_conv_fc_branch(
  84. self, num_branch_convs: int, num_branch_fcs: int,
  85. in_channels: int) -> Tuple[nn.ModuleList, nn.ModuleList, int]:
  86. """Add shared or separable branch. convs -> avg pool (optional) -> fcs.
  87. Args:
  88. num_branch_convs (int): The number of convoluational layers.
  89. num_branch_fcs (int): The number of fully connection layers.
  90. in_channels (int): The input channel of roi features.
  91. Returns:
  92. Tuple[nn.ModuleList, nn.ModuleList, int]: The convs, fcs and the
  93. last layer dimension.
  94. """
  95. last_layer_dim = in_channels
  96. # add branch specific conv layers
  97. branch_convs = nn.ModuleList()
  98. if num_branch_convs > 0:
  99. for i in range(num_branch_convs):
  100. conv_in_channels = (
  101. last_layer_dim if i == 0 else self.conv_out_channels)
  102. branch_convs.append(
  103. ConvModule(
  104. conv_in_channels,
  105. self.conv_out_channels,
  106. 3,
  107. padding=1,
  108. conv_cfg=self.conv_cfg,
  109. norm_cfg=self.norm_cfg))
  110. last_layer_dim = self.conv_out_channels
  111. # add branch specific fc layers
  112. branch_fcs = nn.ModuleList()
  113. if num_branch_fcs > 0:
  114. if not self.with_avg_pool:
  115. last_layer_dim *= self.roi_feat_area
  116. for i in range(num_branch_fcs):
  117. fc_in_channels = (
  118. last_layer_dim if i == 0 else self.fc_out_channels)
  119. branch_fcs.append(
  120. nn.Linear(fc_in_channels, self.fc_out_channels))
  121. last_layer_dim = self.fc_out_channels
  122. return branch_convs, branch_fcs, last_layer_dim
  123. def forward(self, x: Tensor) -> Tensor:
  124. """Forward function.
  125. Args:
  126. x (Tensor): The input features from ROI head.
  127. Returns:
  128. Tensor: The embedding feature map.
  129. """
  130. if self.num_convs > 0:
  131. for conv in self.convs:
  132. x = conv(x)
  133. x = x.flatten(1)
  134. if self.num_fcs > 0:
  135. for fc in self.fcs:
  136. x = self.relu(fc(x))
  137. x = self.fc_embed(x)
  138. return x
  139. def get_targets(
  140. self, gt_match_indices: List[Tensor],
  141. key_sampling_results: List[SamplingResult],
  142. ref_sampling_results: List[SamplingResult]) -> Tuple[List, List]:
  143. """Calculate the track targets and track weights for all samples in a
  144. batch according to the sampling_results.
  145. Args:
  146. gt_match_indices (list(Tensor)): Mapping from gt_instance_ids to
  147. ref_gt_instance_ids of the same tracklet in a pair of images.
  148. key_sampling_results (List[obj:SamplingResult]): Assign results of
  149. all images in a batch after sampling.
  150. ref_sampling_results (List[obj:SamplingResult]): Assign results of
  151. all reference images in a batch after sampling.
  152. Returns:
  153. Tuple[list[Tensor]]: Association results.
  154. Containing the following list of Tensors:
  155. - track_targets (list[Tensor]): The mapping instance ids from
  156. all positive proposals in the key image to all proposals
  157. in the reference image, each tensor in list has
  158. shape (len(key_pos_bboxes), len(ref_bboxes)).
  159. - track_weights (list[Tensor]): Loss weights for all positive
  160. proposals in a batch, each tensor in list has
  161. shape (len(key_pos_bboxes),).
  162. """
  163. track_targets = []
  164. track_weights = []
  165. for _gt_match_indices, key_res, ref_res in zip(gt_match_indices,
  166. key_sampling_results,
  167. ref_sampling_results):
  168. targets = _gt_match_indices.new_zeros(
  169. (key_res.pos_bboxes.size(0), ref_res.bboxes.size(0)),
  170. dtype=torch.int)
  171. _match_indices = _gt_match_indices[key_res.pos_assigned_gt_inds]
  172. pos2pos = (_match_indices.view(
  173. -1, 1) == ref_res.pos_assigned_gt_inds.view(1, -1)).int()
  174. targets[:, :pos2pos.size(1)] = pos2pos
  175. weights = (targets.sum(dim=1) > 0).float()
  176. track_targets.append(targets)
  177. track_weights.append(weights)
  178. return track_targets, track_weights
  179. def match(
  180. self, key_embeds: Tensor, ref_embeds: Tensor,
  181. key_sampling_results: List[SamplingResult],
  182. ref_sampling_results: List[SamplingResult]
  183. ) -> Tuple[List[Tensor], List[Tensor]]:
  184. """Calculate the dist matrixes for loss measurement.
  185. Args:
  186. key_embeds (Tensor): Embeds of positive bboxes in sampling results
  187. of key image.
  188. ref_embeds (Tensor): Embeds of all bboxes in sampling results
  189. of the reference image.
  190. key_sampling_results (List[obj:SamplingResults]): Assign results of
  191. all images in a batch after sampling.
  192. ref_sampling_results (List[obj:SamplingResults]): Assign results of
  193. all reference images in a batch after sampling.
  194. Returns:
  195. Tuple[list[Tensor]]: Calculation results.
  196. Containing the following list of Tensors:
  197. - dists (list[Tensor]): Dot-product dists between
  198. key_embeds and ref_embeds, each tensor in list has
  199. shape (len(key_pos_bboxes), len(ref_bboxes)).
  200. - cos_dists (list[Tensor]): Cosine dists between
  201. key_embeds and ref_embeds, each tensor in list has
  202. shape (len(key_pos_bboxes), len(ref_bboxes)).
  203. """
  204. num_key_rois = [res.pos_bboxes.size(0) for res in key_sampling_results]
  205. key_embeds = torch.split(key_embeds, num_key_rois)
  206. num_ref_rois = [res.bboxes.size(0) for res in ref_sampling_results]
  207. ref_embeds = torch.split(ref_embeds, num_ref_rois)
  208. dists, cos_dists = [], []
  209. for key_embed, ref_embed in zip(key_embeds, ref_embeds):
  210. dist = embed_similarity(
  211. key_embed,
  212. ref_embed,
  213. method='dot_product',
  214. temperature=self.softmax_temp)
  215. dists.append(dist)
  216. if self.loss_track_aux is not None:
  217. cos_dist = embed_similarity(
  218. key_embed, ref_embed, method='cosine')
  219. cos_dists.append(cos_dist)
  220. else:
  221. cos_dists.append(None)
  222. return dists, cos_dists
  223. def loss(self, key_roi_feats: Tensor, ref_roi_feats: Tensor,
  224. key_sampling_results: List[SamplingResult],
  225. ref_sampling_results: List[SamplingResult],
  226. gt_match_indices_list: List[Tensor]) -> dict:
  227. """Calculate the track loss and the auxiliary track loss.
  228. Args:
  229. key_roi_feats (Tensor): Embeds of positive bboxes in sampling
  230. results of key image.
  231. ref_roi_feats (Tensor): Embeds of all bboxes in sampling results
  232. of the reference image.
  233. key_sampling_results (List[obj:SamplingResults]): Assign results of
  234. all images in a batch after sampling.
  235. ref_sampling_results (List[obj:SamplingResults]): Assign results of
  236. all reference images in a batch after sampling.
  237. gt_match_indices_list (list(Tensor)): Mapping from gt_instances_ids
  238. to ref_gt_instances_ids of the same tracklet in a pair of
  239. images.
  240. Returns:
  241. Dict [str: Tensor]: Calculation results.
  242. Containing the following list of Tensors:
  243. - loss_track (Tensor): Results of loss_track function.
  244. - loss_track_aux (Tensor): Results of loss_track_aux function.
  245. """
  246. key_track_feats = self(key_roi_feats)
  247. ref_track_feats = self(ref_roi_feats)
  248. losses = self.loss_by_feat(key_track_feats, ref_track_feats,
  249. key_sampling_results, ref_sampling_results,
  250. gt_match_indices_list)
  251. return losses
  252. def loss_by_feat(self, key_track_feats: Tensor, ref_track_feats: Tensor,
  253. key_sampling_results: List[SamplingResult],
  254. ref_sampling_results: List[SamplingResult],
  255. gt_match_indices_list: List[Tensor]) -> dict:
  256. """Calculate the track loss and the auxiliary track loss.
  257. Args:
  258. key_track_feats (Tensor): Embeds of positive bboxes in sampling
  259. results of key image.
  260. ref_track_feats (Tensor): Embeds of all bboxes in sampling results
  261. of the reference image.
  262. key_sampling_results (List[obj:SamplingResults]): Assign results of
  263. all images in a batch after sampling.
  264. ref_sampling_results (List[obj:SamplingResults]): Assign results of
  265. all reference images in a batch after sampling.
  266. gt_match_indices_list (list(Tensor)): Mapping from instances_ids
  267. from key image to reference image of the same tracklet in a
  268. pair of images.
  269. Returns:
  270. Dict [str: Tensor]: Calculation results.
  271. Containing the following list of Tensors:
  272. - loss_track (Tensor): Results of loss_track function.
  273. - loss_track_aux (Tensor): Results of loss_track_aux function.
  274. """
  275. dists, cos_dists = self.match(key_track_feats, ref_track_feats,
  276. key_sampling_results,
  277. ref_sampling_results)
  278. targets, weights = self.get_targets(gt_match_indices_list,
  279. key_sampling_results,
  280. ref_sampling_results)
  281. losses = dict()
  282. loss_track = 0.
  283. loss_track_aux = 0.
  284. for _dists, _cos_dists, _targets, _weights in zip(
  285. dists, cos_dists, targets, weights):
  286. loss_track += self.loss_track(
  287. _dists, _targets, _weights, avg_factor=_weights.sum())
  288. if self.loss_track_aux is not None:
  289. loss_track_aux += self.loss_track_aux(_cos_dists, _targets)
  290. losses['loss_track'] = loss_track / len(dists)
  291. if self.loss_track_aux is not None:
  292. losses['loss_track_aux'] = loss_track_aux / len(dists)
  293. return losses
  294. def predict(self, bbox_feats: Tensor) -> Tensor:
  295. """Perform forward propagation of the tracking head and predict
  296. tracking results on the features of the upstream network.
  297. Args:
  298. bbox_feats: The extracted roi features.
  299. Returns:
  300. Tensor: The extracted track features.
  301. """
  302. track_feats = self(bbox_feats)
  303. return track_feats