roi_embed_head.py

# Copyright (c) OpenMMLab. All rights reserved.
from collections import defaultdict
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule
from torch import Tensor
from torch.nn.modules.utils import _pair

from mmdet.models.losses import accuracy
from mmdet.models.task_modules import SamplingResult
from mmdet.models.task_modules.tracking import embed_similarity
from mmdet.registry import MODELS


@MODELS.register_module()
class RoIEmbedHead(BaseModule):
    """The roi embed head.

    This module is used in multi-object tracking methods, such as MaskTrack
    R-CNN.

    Args:
        num_convs (int): The number of convolutional layers to embed roi
            features. Defaults to 0.
        num_fcs (int): The number of fully connected layers to embed roi
            features. Defaults to 0.
        roi_feat_size (int | tuple(int)): The spatial size of roi features.
            Defaults to 7.
        in_channels (int): The input channel of roi features. Defaults to 256.
        conv_out_channels (int): The output channel of roi features after
            forwarding convolutional layers. Defaults to 256.
        with_avg_pool (bool): Whether to use average pooling before passing
            roi features into fully connected layers. Defaults to False.
        fc_out_channels (int): The output channel of roi features after
            forwarding fully connected layers. Defaults to 1024.
        conv_cfg (dict): Config dict for convolution layer. Defaults to None,
            which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer. Defaults to None.
        loss_match (dict): The loss function. Defaults to
            dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0).
        init_cfg (dict): Configuration of initialization. Defaults to None.
    """

    def __init__(self,
                 num_convs: int = 0,
                 num_fcs: int = 0,
                 roi_feat_size: int = 7,
                 in_channels: int = 256,
                 conv_out_channels: int = 256,
                 with_avg_pool: bool = False,
                 fc_out_channels: int = 1024,
                 conv_cfg: Optional[dict] = None,
                 norm_cfg: Optional[dict] = None,
                 loss_match: dict = dict(
                     type='mmdet.CrossEntropyLoss',
                     use_sigmoid=False,
                     loss_weight=1.0),
                 init_cfg: Optional[dict] = None,
                 **kwargs):
        super(RoIEmbedHead, self).__init__(init_cfg=init_cfg)
        self.num_convs = num_convs
        self.num_fcs = num_fcs
        self.roi_feat_size = _pair(roi_feat_size)
        self.roi_feat_area = self.roi_feat_size[0] * self.roi_feat_size[1]
        self.in_channels = in_channels
        self.conv_out_channels = conv_out_channels
        self.with_avg_pool = with_avg_pool
        self.fc_out_channels = fc_out_channels
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.loss_match = MODELS.build(loss_match)
        self.fp16_enabled = False

        if self.with_avg_pool:
            self.avg_pool = nn.AvgPool2d(self.roi_feat_size)
        # add convs and fcs
        self.convs, self.fcs, self.last_layer_dim = self._add_conv_fc_branch(
            self.num_convs, self.num_fcs, self.in_channels)
        self.relu = nn.ReLU(inplace=True)

    def _add_conv_fc_branch(
            self, num_branch_convs: int, num_branch_fcs: int,
            in_channels: int) -> Tuple[nn.ModuleList, nn.ModuleList, int]:
        """Add shared or separable branch.

        convs -> avg pool (optional) -> fcs
        """
        last_layer_dim = in_channels
        # add branch specific conv layers
        branch_convs = nn.ModuleList()
        if num_branch_convs > 0:
            for i in range(num_branch_convs):
                conv_in_channels = (
                    last_layer_dim if i == 0 else self.conv_out_channels)
                branch_convs.append(
                    ConvModule(
                        conv_in_channels,
                        self.conv_out_channels,
                        3,
                        padding=1,
                        conv_cfg=self.conv_cfg,
                        norm_cfg=self.norm_cfg))
            last_layer_dim = self.conv_out_channels

        # add branch specific fc layers
        branch_fcs = nn.ModuleList()
        if num_branch_fcs > 0:
            if not self.with_avg_pool:
                last_layer_dim *= self.roi_feat_area
            for i in range(num_branch_fcs):
                fc_in_channels = (
                    last_layer_dim if i == 0 else self.fc_out_channels)
                branch_fcs.append(
                    nn.Linear(fc_in_channels, self.fc_out_channels))
            last_layer_dim = self.fc_out_channels
        return branch_convs, branch_fcs, last_layer_dim
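
    # Illustrative note (not part of the original implementation): with
    # MaskTrack R-CNN style settings, e.g. num_convs=0, num_fcs=2,
    # in_channels=256, roi_feat_size=7 and with_avg_pool=False, the branch
    # built above flattens each [256, 7, 7] roi feature to 256 * 49 = 12544
    # values and maps it through Linear(12544, 1024) -> Linear(1024, 1024),
    # so every roi is embedded into a 1024-d vector.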

    @property
    def custom_activation(self):
        return getattr(self.loss_match, 'custom_activation', False)

    def extract_feat(self, x: Tensor,
                     num_x_per_img: List[int]) -> Tuple[Tensor]:
        """Extract feature from the input `x`, and split the output to a list.

        Args:
            x (Tensor): of shape [N, C, H, W]. N is the number of proposals.
            num_x_per_img (list[int]): The `x` contains proposals of
                multiple images. `num_x_per_img` denotes the number of
                proposals for each image.

        Returns:
            list[Tensor]: Each Tensor denotes the embed features belonging to
            an image in a batch.
        """
        if self.num_convs > 0:
            for conv in self.convs:
                x = conv(x)

        if self.num_fcs > 0:
            if self.with_avg_pool:
                x = self.avg_pool(x)
            x = x.flatten(1)
            for fc in self.fcs:
                x = self.relu(fc(x))
        else:
            x = x.flatten(1)

        x_split = torch.split(x, num_x_per_img, dim=0)
        return x_split
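
    # Usage sketch (hypothetical shapes, for illustration only): if `x` holds
    # 10 roi features from two images with 4 and 6 proposals respectively,
    #   x_split = self.extract_feat(x, [4, 6])
    # returns a tuple of two tensors with shapes [4, D] and [6, D], where D is
    # the embedding dim (e.g. 1024 with the default two-fc branch).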

    def forward(
        self, x: Tensor, ref_x: Tensor, num_x_per_img: List[int],
        num_x_per_ref_img: List[int]
    ) -> Tuple[Tuple[Tensor], Tuple[Tensor]]:
        """Extract the embed features of `x` and `ref_x` and split them per
        image.

        Args:
            x (Tensor): of shape [N, C, H, W]. N is the number of key frame
                proposals.
            ref_x (Tensor): of shape [M, C, H, W]. M is the number of
                reference frame proposals.
            num_x_per_img (list[int]): The `x` contains proposals of
                multiple images. `num_x_per_img` denotes the number of
                proposals for each key image.
            num_x_per_ref_img (list[int]): The `ref_x` contains proposals of
                multiple images. `num_x_per_ref_img` denotes the number of
                proposals for each reference image.

        Returns:
            tuple[tuple[Tensor], tuple[Tensor]]: Each tuple of tensors denotes
            the embed features belonging to an image in a batch.
        """
        x_split = self.extract_feat(x, num_x_per_img)
        ref_x_split = self.extract_feat(ref_x, num_x_per_ref_img)

        return x_split, ref_x_split

    def get_targets(self, sampling_results: List[SamplingResult],
                    gt_instance_ids: List[Tensor],
                    ref_gt_instance_ids: List[Tensor]) -> Tuple[List, List]:
        """Calculate the ground truth for all samples in a batch according to
        the sampling_results.

        Args:
            sampling_results (List[obj:SamplingResult]): Assign results of
                all images in a batch after sampling.
            gt_instance_ids (list[Tensor]): The instance ids of gt_bboxes of
                all images in a batch, each tensor has shape (num_gt, ).
            ref_gt_instance_ids (list[Tensor]): The instance ids of gt_bboxes
                of all reference images in a batch, each tensor has shape
                (num_gt, ).

        Returns:
            Tuple[list[Tensor]]: Ground truth for proposals in a batch.
            Containing the following list of Tensors:

                - track_id_targets (list[Tensor]): The instance id targets
                  for all proposals in a batch, each tensor in list has
                  shape (num_proposals,).
                - track_id_weights (list[Tensor]): Label weights for
                  all proposals in a batch, each tensor in list has
                  shape (num_proposals,).
        """
        track_id_targets = []
        track_id_weights = []

        for res, gt_instance_id, ref_gt_instance_id in zip(
                sampling_results, gt_instance_ids, ref_gt_instance_ids):
            pos_instance_ids = gt_instance_id[res.pos_assigned_gt_inds]
            pos_match_id = gt_instance_id.new_zeros(len(pos_instance_ids))
            for i, id in enumerate(pos_instance_ids):
                if id in ref_gt_instance_id:
                    pos_match_id[i] = ref_gt_instance_id.tolist().index(id) + 1

            track_id_target = gt_instance_id.new_zeros(
                len(res.bboxes), dtype=torch.int64)
            track_id_target[:len(res.pos_bboxes)] = pos_match_id
            track_id_weight = res.bboxes.new_zeros(len(res.bboxes))
            track_id_weight[:len(res.pos_bboxes)] = 1.0

            track_id_targets.append(track_id_target)
            track_id_weights.append(track_id_weight)
        return track_id_targets, track_id_weights
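
    # Worked example (hypothetical ids, for illustration only): suppose one
    # key image has positive proposals assigned to gt instance ids [5, 9, 2]
    # and its reference image contains gt instance ids [9, 5]. The positive
    # targets become [2, 1, 0]: id 5 matches reference index 0 (+1), id 9
    # matches reference index 1 (+1), and id 2 has no match, so its target
    # stays 0 (the "no match" class). Negative proposals are appended with
    # target 0 and weight 0, so only positives contribute to the loss.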

    def loss(
        self,
        bbox_feats: Tensor,
        ref_bbox_feats: Tensor,
        num_bbox_per_img: List[int],
        num_bbox_per_ref_img: List[int],
        sampling_results: List[SamplingResult],
        gt_instance_ids: List[Tensor],
        ref_gt_instance_ids: List[Tensor],
        reduction_override: Optional[str] = None,
    ) -> dict:
        """Calculate the loss in a batch.

        Args:
            bbox_feats (Tensor): of shape [N, C, H, W]. N is the number of
                bboxes.
            ref_bbox_feats (Tensor): of shape [M, C, H, W]. M is the number
                of reference bboxes.
            num_bbox_per_img (list[int]): The `bbox_feats` contains proposals
                of multiple images. `num_bbox_per_img` denotes the number of
                proposals for each key image.
            num_bbox_per_ref_img (list[int]): The `ref_bbox_feats` contains
                proposals of multiple images. `num_bbox_per_ref_img` denotes
                the number of proposals for each reference image.
            sampling_results (List[obj:SamplingResult]): Assign results of
                all images in a batch after sampling.
            gt_instance_ids (list[Tensor]): The instance ids of gt_bboxes of
                all images in a batch, each tensor has shape (num_gt, ).
            ref_gt_instance_ids (list[Tensor]): The instance ids of gt_bboxes
                of all reference images in a batch, each tensor has shape
                (num_gt, ).
            reduction_override (str, optional): The method used to reduce the
                loss. Options are "none", "mean" and "sum".

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        x_split, ref_x_split = self(bbox_feats, ref_bbox_feats,
                                    num_bbox_per_img, num_bbox_per_ref_img)

        losses = self.loss_by_feat(x_split, ref_x_split, sampling_results,
                                   gt_instance_ids, ref_gt_instance_ids,
                                   reduction_override)
        return losses

    def loss_by_feat(self,
                     x_split: Tuple[Tensor],
                     ref_x_split: Tuple[Tensor],
                     sampling_results: List[SamplingResult],
                     gt_instance_ids: List[Tensor],
                     ref_gt_instance_ids: List[Tensor],
                     reduction_override: Optional[str] = None) -> dict:
        """Calculate losses.

        Args:
            x_split (tuple[Tensor]): The embed features belonging to key
                images.
            ref_x_split (tuple[Tensor]): The embed features belonging to ref
                images.
            sampling_results (List[obj:SamplingResult]): Assign results of
                all images in a batch after sampling.
            gt_instance_ids (list[Tensor]): The instance ids of gt_bboxes of
                all images in a batch, each tensor has shape (num_gt, ).
            ref_gt_instance_ids (list[Tensor]): The instance ids of gt_bboxes
                of all reference images in a batch, each tensor has shape
                (num_gt, ).
            reduction_override (str, optional): The method used to reduce the
                loss. Options are "none", "mean" and "sum".

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        track_id_targets, track_id_weights = self.get_targets(
            sampling_results, gt_instance_ids, ref_gt_instance_ids)
        assert isinstance(track_id_targets, list)
        assert isinstance(track_id_weights, list)
        assert len(track_id_weights) == len(track_id_targets)

        losses = defaultdict(list)
        similarity_logits = []
        for one_x, one_ref_x in zip(x_split, ref_x_split):
            similarity_logit = embed_similarity(
                one_x, one_ref_x, method='dot_product')
            dummy = similarity_logit.new_zeros(one_x.shape[0], 1)
            similarity_logit = torch.cat((dummy, similarity_logit), dim=1)
            similarity_logits.append(similarity_logit)

        assert isinstance(similarity_logits, list)
        assert len(similarity_logits) == len(track_id_targets)

        for similarity_logit, track_id_target, track_id_weight in zip(
                similarity_logits, track_id_targets, track_id_weights):
            avg_factor = max(torch.sum(track_id_target > 0).float().item(), 1.)
            if similarity_logit.numel() > 0:
                loss_match = self.loss_match(
                    similarity_logit,
                    track_id_target,
                    track_id_weight,
                    avg_factor=avg_factor,
                    reduction_override=reduction_override)
                if isinstance(loss_match, dict):
                    for key, value in loss_match.items():
                        losses[key].append(value)
                else:
                    losses['loss_match'].append(loss_match)

                valid_index = track_id_weight > 0
                valid_similarity_logit = similarity_logit[valid_index]
                valid_track_id_target = track_id_target[valid_index]
                if self.custom_activation:
                    match_accuracy = self.loss_match.get_accuracy(
                        valid_similarity_logit, valid_track_id_target)
                    for key, value in match_accuracy.items():
                        losses[key].append(value)
                else:
                    losses['match_accuracy'].append(
                        accuracy(valid_similarity_logit,
                                 valid_track_id_target))

        for key, value in losses.items():
            losses[key] = sum(losses[key]) / len(similarity_logits)
        return losses
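
    # Shape note (for illustration only): for an image with P proposal
    # embeddings in `one_x` and G reference embeddings in `one_ref_x`, the
    # `similarity_logit` built above has shape [P, 1 + G]. Column 0 is the
    # appended dummy "no match" class, and column j (j >= 1) is the
    # dot-product similarity to the (j - 1)-th reference embedding, which
    # matches the `+ 1` offset used in `get_targets`. The cross-entropy loss
    # therefore treats matching as a (G + 1)-way classification per proposal.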

    def predict(self, roi_feats: Tensor,
                prev_roi_feats: Tensor) -> List[Tensor]:
        """Perform forward propagation of the tracking head and predict
        tracking results on the features of the upstream network.

        Args:
            roi_feats (Tensor): Feature maps of the current image's rois.
            prev_roi_feats (Tensor): Feature maps of the previous image's
                rois.

        Returns:
            list[Tensor]: The predicted similarity_logits of each pair of key
            image and reference image.
        """
        x_split, ref_x_split = self(roi_feats, prev_roi_feats,
                                    [roi_feats.shape[0]],
                                    [prev_roi_feats.shape[0]])

        similarity_logits = self.predict_by_feat(x_split, ref_x_split)

        return similarity_logits

    def predict_by_feat(self, x_split: Tuple[Tensor],
                        ref_x_split: Tuple[Tensor]) -> List[Tensor]:
        """Get similarity_logits.

        Args:
            x_split (tuple[Tensor]): The embed features belonging to key
                images.
            ref_x_split (tuple[Tensor]): The embed features belonging to ref
                images.

        Returns:
            list[Tensor]: The predicted similarity_logits of each pair of key
            image and reference image.
        """
        similarity_logits = []
        for one_x, one_ref_x in zip(x_split, ref_x_split):
            similarity_logit = embed_similarity(
                one_x, one_ref_x, method='dot_product')
            dummy = similarity_logit.new_zeros(one_x.shape[0], 1)
            similarity_logit = torch.cat((dummy, similarity_logit), dim=1)
            similarity_logits.append(similarity_logit)
        return similarity_logits
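

# Usage sketch (not part of the original file; the shapes and the direct
# instantiation below are illustrative assumptions):
#
#     head = RoIEmbedHead(num_fcs=2, in_channels=256, fc_out_channels=1024)
#     roi_feats = torch.randn(5, 256, 7, 7)       # 5 rois in the current frame
#     prev_roi_feats = torch.randn(8, 256, 7, 7)  # 8 rois in the previous frame
#     logits = head.predict(roi_feats, prev_roi_feats)
#     # -> a list with one tensor of shape [5, 1 + 8]; for each current roi,
#     #    argmax 0 suggests "no match / new object", while argmax j >= 1
#     #    suggests a match with the (j - 1)-th previous-frame roi.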