mask2former_track_head.py

# Copyright (c) OpenMMLab. All rights reserved.
import copy
from collections import defaultdict
from typing import Dict, List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import Conv2d
from mmcv.ops import point_sample
from mmengine.model import ModuleList
from mmengine.model.weight_init import caffe2_xavier_init
from mmengine.structures import InstanceData
from torch import Tensor

from mmdet.models.dense_heads import AnchorFreeHead, MaskFormerHead
from mmdet.models.utils import get_uncertain_point_coords_with_randomness
from mmdet.registry import MODELS, TASK_UTILS
from mmdet.structures import TrackDataSample, TrackSampleList
from mmdet.structures.mask import mask2bbox
from mmdet.utils import (ConfigType, InstanceList, OptConfigType,
                         OptMultiConfig, reduce_mean)

from ..layers import Mask2FormerTransformerDecoder

@MODELS.register_module()
class Mask2FormerTrackHead(MaskFormerHead):
    """Implements the Mask2Former head for video instance segmentation.

    See `Masked-attention Mask Transformer for Universal Image
    Segmentation <https://arxiv.org/pdf/2112.01527>`_ for details.

    Args:
        in_channels (list[int]): Number of channels in the input feature map.
        feat_channels (int): Number of channels for features.
        out_channels (int): Number of channels for output.
        num_classes (int): Number of VIS classes.
        num_frames (int): Number of frames sampled per video clip.
            Defaults to 2.
        num_queries (int): Number of queries in the Transformer decoder.
            Defaults to 100.
        num_transformer_feat_level (int): Number of feature levels fed to
            the Transformer decoder. Defaults to 3.
        pixel_decoder (:obj:`ConfigDict` or dict): Config for the pixel
            decoder.
        enforce_decoder_input_project (bool, optional): Whether to add
            a layer to change the embed_dim of the transformer encoder in the
            pixel decoder to the embed_dim of the transformer decoder.
            Defaults to False.
        transformer_decoder (:obj:`ConfigDict` or dict): Config for the
            transformer decoder.
        positional_encoding (:obj:`ConfigDict` or dict): Config for the
            transformer decoder position encoding.
            Defaults to `SinePositionalEncoding3D`.
        loss_cls (:obj:`ConfigDict` or dict): Config of the classification
            loss. Defaults to `CrossEntropyLoss`.
        loss_mask (:obj:`ConfigDict` or dict): Config of the mask loss.
            Defaults to `CrossEntropyLoss`.
        loss_dice (:obj:`ConfigDict` or dict): Config of the dice loss.
            Defaults to `DiceLoss`.
        train_cfg (:obj:`ConfigDict` or dict, optional): Training config of
            the Mask2Former head. Defaults to None.
        test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of
            the Mask2Former head. Defaults to None.
        init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \
            dict], optional): Initialization config dict. Defaults to None.
    """

    def __init__(self,
                 in_channels: List[int],
                 feat_channels: int,
                 out_channels: int,
                 num_classes: int,
                 num_frames: int = 2,
                 num_queries: int = 100,
                 num_transformer_feat_level: int = 3,
                 pixel_decoder: ConfigType = ...,
                 enforce_decoder_input_project: bool = False,
                 transformer_decoder: ConfigType = ...,
                 positional_encoding: ConfigType = dict(
                     num_feats=128, normalize=True),
                 loss_cls: ConfigType = dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=False,
                     loss_weight=2.0,
                     reduction='mean',
                     class_weight=[1.0] * 133 + [0.1]),
                 loss_mask: ConfigType = dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=True,
                     reduction='mean',
                     loss_weight=5.0),
                 loss_dice: ConfigType = dict(
                     type='DiceLoss',
                     use_sigmoid=True,
                     activate=True,
                     reduction='mean',
                     naive_dice=True,
                     eps=1.0,
                     loss_weight=5.0),
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None,
                 init_cfg: OptMultiConfig = None,
                 **kwargs) -> None:
        super(AnchorFreeHead, self).__init__(init_cfg=init_cfg)
        self.num_classes = num_classes
        self.num_frames = num_frames
        self.num_queries = num_queries
        self.num_transformer_feat_level = num_transformer_feat_level
        self.num_heads = transformer_decoder.layer_cfg.cross_attn_cfg.num_heads
        self.num_transformer_decoder_layers = transformer_decoder.num_layers
        assert pixel_decoder.encoder.layer_cfg. \
            self_attn_cfg.num_levels == num_transformer_feat_level
        pixel_decoder_ = copy.deepcopy(pixel_decoder)
        pixel_decoder_.update(
            in_channels=in_channels,
            feat_channels=feat_channels,
            out_channels=out_channels)
        self.pixel_decoder = MODELS.build(pixel_decoder_)
        self.transformer_decoder = Mask2FormerTransformerDecoder(
            **transformer_decoder)
        self.decoder_embed_dims = self.transformer_decoder.embed_dims

        self.decoder_input_projs = ModuleList()
        # from low resolution to high resolution
        for _ in range(num_transformer_feat_level):
            if (self.decoder_embed_dims != feat_channels
                    or enforce_decoder_input_project):
                self.decoder_input_projs.append(
                    Conv2d(
                        feat_channels, self.decoder_embed_dims, kernel_size=1))
            else:
                self.decoder_input_projs.append(nn.Identity())
        self.decoder_positional_encoding = MODELS.build(positional_encoding)
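        # Learnable queries: `query_feat` provides the initial query content
        # features and `query_embed` the corresponding query positional
        # embeddings used by the transformer decoder.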
        self.query_embed = nn.Embedding(self.num_queries, feat_channels)
        self.query_feat = nn.Embedding(self.num_queries, feat_channels)
        # from low resolution to high resolution
        self.level_embed = nn.Embedding(self.num_transformer_feat_level,
                                        feat_channels)

        self.cls_embed = nn.Linear(feat_channels, self.num_classes + 1)
        self.mask_embed = nn.Sequential(
            nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True),
            nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True),
            nn.Linear(feat_channels, out_channels))

        self.test_cfg = test_cfg
        self.train_cfg = train_cfg
        if train_cfg:
            self.assigner = TASK_UTILS.build(self.train_cfg.assigner)
            self.sampler = TASK_UTILS.build(
                self.train_cfg['sampler'], default_args=dict(context=self))
            self.num_points = self.train_cfg.get('num_points', 12544)
            self.oversample_ratio = self.train_cfg.get('oversample_ratio', 3.0)
            self.importance_sample_ratio = self.train_cfg.get(
                'importance_sample_ratio', 0.75)

        self.class_weight = loss_cls.class_weight
        self.loss_cls = MODELS.build(loss_cls)
        self.loss_mask = MODELS.build(loss_mask)
        self.loss_dice = MODELS.build(loss_dice)

    def init_weights(self) -> None:
        for m in self.decoder_input_projs:
            if isinstance(m, Conv2d):
                caffe2_xavier_init(m, bias=0)

        self.pixel_decoder.init_weights()

        for p in self.transformer_decoder.parameters():
            if p.dim() > 1:
                nn.init.xavier_normal_(p)

    def preprocess_gt(self, batch_gt_instances: InstanceList) -> InstanceList:
        """Preprocess the ground truth for all images.

        It aims to reorganize the `gt`. For example, in the
        `batch_data_sample.gt_instances.mask`, its shape is
        `(all_num_gts, h, w)`, but we don't know which `img` each gt belongs
        to (assume `num_frames` is 2). So, this func is used to reshape the
        `gt_mask` to `(num_gts_per_img, num_frames, h, w)`. In addition, we
        can't guarantee that the number of instances in these two images is
        equal, so `-1` refers to nonexistent instances.

        Args:
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``labels``, each is
                ground truth labels of each bbox, with shape (num_gts, )
                and ``masks``, each is ground truth masks of each instances
                of an image, shape (num_gts, h, w).

        Returns:
            list[obj:`InstanceData`]: each contains the following keys

                - labels (Tensor): Ground truth class indices\
                    for an image, with shape (n, ), n is the sum of\
                    number of stuff type and number of instance in an image.
                - masks (Tensor): Ground truth mask for a\
                    image, with shape (n, t, h, w).
        """
        final_batch_gt_instances = []
        batch_size = len(batch_gt_instances) // self.num_frames
        for batch_idx in range(batch_size):
            pair_gt_instances = batch_gt_instances[batch_idx *
                                                   self.num_frames:batch_idx *
                                                   self.num_frames +
                                                   self.num_frames]

            assert len(
                pair_gt_instances
            ) > 1, f'mask2former for vis need multi frames to train, \
                but you only use {len(pair_gt_instances)} frames'

            _device = pair_gt_instances[0].labels.device

            for gt_instances in pair_gt_instances:
                gt_instances.masks = gt_instances.masks.to_tensor(
                    dtype=torch.bool, device=_device)
            # map every instance id appearing in the clip to a row index
            all_ins_id = torch.cat([
                gt_instances.instances_ids
                for gt_instances in pair_gt_instances
            ])
            all_ins_id = all_ins_id.unique().tolist()
            map_ins_id = dict()
            for i, ins_id in enumerate(all_ins_id):
                map_ins_id[ins_id] = i
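            # Allocate per-video targets: one row per unique instance id in
            # the clip. Frames where an instance is absent keep an all-zero
            # mask, instance id -1 and (until seen) label -1.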
            num_instances = len(all_ins_id)
            mask_shape = [
                num_instances, self.num_frames,
                pair_gt_instances[0].masks.shape[1],
                pair_gt_instances[0].masks.shape[2]
            ]
            gt_masks_per_video = torch.zeros(
                mask_shape, dtype=torch.bool, device=_device)
            gt_ids_per_video = torch.full((num_instances, self.num_frames),
                                          -1,
                                          dtype=torch.long,
                                          device=_device)
            gt_labels_per_video = torch.full((num_instances, ),
                                             -1,
                                             dtype=torch.long,
                                             device=_device)

            for frame_id in range(self.num_frames):
                cur_frame_gts = pair_gt_instances[frame_id]
                ins_ids = cur_frame_gts.instances_ids.tolist()
                for i, ins_id in enumerate(ins_ids):
                    gt_masks_per_video[map_ins_id[ins_id],
                                       frame_id, :, :] = cur_frame_gts.masks[i]
                    gt_ids_per_video[map_ins_id[ins_id],
                                     frame_id] = cur_frame_gts.instances_ids[i]
                    gt_labels_per_video[
                        map_ins_id[ins_id]] = cur_frame_gts.labels[i]

            tmp_instances = InstanceData(
                labels=gt_labels_per_video,
                masks=gt_masks_per_video.long(),
                instances_id=gt_ids_per_video)
            final_batch_gt_instances.append(tmp_instances)

        return final_batch_gt_instances

    def _get_targets_single(self, cls_score: Tensor, mask_pred: Tensor,
                            gt_instances: InstanceData,
                            img_meta: dict) -> Tuple[Tensor]:
        """Compute classification and mask targets for one image.

        Args:
            cls_score (Tensor): Mask score logits from a single decoder layer
                for one image. Shape (num_queries, cls_out_channels).
            mask_pred (Tensor): Mask logits for a single decoder layer for one
                image. Shape (num_queries, num_frames, h, w).
            gt_instances (:obj:`InstanceData`): It contains ``labels`` and
                ``masks``.
            img_meta (dict): Image information.

        Returns:
            tuple[Tensor]: A tuple containing the following for one image.

                - labels (Tensor): Labels of each image. \
                    shape (num_queries, ).
                - label_weights (Tensor): Label weights of each image. \
                    shape (num_queries, ).
                - mask_targets (Tensor): Mask targets of each image. \
                    shape (num_queries, num_frames, h, w).
                - mask_weights (Tensor): Mask weights of each image. \
                    shape (num_queries, ).
                - pos_inds (Tensor): Sampled positive indices for each \
                    image.
                - neg_inds (Tensor): Sampled negative indices for each \
                    image.
                - sampling_result (:obj:`SamplingResult`): Sampling results.
        """
        # (num_gts, )
        gt_labels = gt_instances.labels
        # (num_gts, num_frames, h, w)
        gt_masks = gt_instances.masks
        # sample points
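        # The assignment cost is computed on a shared set of `num_points`
        # randomly sampled locations instead of the full-resolution masks,
        # which keeps the matching cost matrix small.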
        num_queries = cls_score.shape[0]
        num_gts = gt_labels.shape[0]
        point_coords = torch.rand((1, self.num_points, 2),
                                  device=cls_score.device)
        # shape (num_queries, num_frames * num_points)
        mask_points_pred = point_sample(mask_pred,
                                        point_coords.repeat(num_queries, 1,
                                                            1)).flatten(1)
        # shape (num_gts, num_frames * num_points)
        gt_points_masks = point_sample(gt_masks.float(),
                                       point_coords.repeat(num_gts, 1,
                                                           1)).flatten(1)

        sampled_gt_instances = InstanceData(
            labels=gt_labels, masks=gt_points_masks)
        sampled_pred_instances = InstanceData(
            scores=cls_score, masks=mask_points_pred)
        # assign and sample
        assign_result = self.assigner.assign(
            pred_instances=sampled_pred_instances,
            gt_instances=sampled_gt_instances,
            img_meta=img_meta)
        pred_instances = InstanceData(scores=cls_score, masks=mask_pred)
        sampling_result = self.sampler.sample(
            assign_result=assign_result,
            pred_instances=pred_instances,
            gt_instances=gt_instances)
        pos_inds = sampling_result.pos_inds
        neg_inds = sampling_result.neg_inds

        # label target
        labels = gt_labels.new_full((self.num_queries, ),
                                    self.num_classes,
                                    dtype=torch.long)
        labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds]
        label_weights = gt_labels.new_ones((self.num_queries, ))

        # mask target
        mask_targets = gt_masks[sampling_result.pos_assigned_gt_inds]
        mask_weights = mask_pred.new_zeros((self.num_queries, ))
        mask_weights[pos_inds] = 1.0

        return (labels, label_weights, mask_targets, mask_weights, pos_inds,
                neg_inds, sampling_result)

    def _loss_by_feat_single(self, cls_scores: Tensor, mask_preds: Tensor,
                             batch_gt_instances: List[InstanceData],
                             batch_img_metas: List[dict]) -> Tuple[Tensor]:
        """Loss function for outputs from a single decoder layer.

        Args:
            cls_scores (Tensor): Mask score logits from a single decoder layer
                for all images. Shape (batch_size, num_queries,
                cls_out_channels). Note `cls_out_channels` should include
                background.
            mask_preds (Tensor): Mask logits from a single decoder layer for
                all images. Shape (batch_size, num_queries, num_frames, h, w).
            batch_gt_instances (list[obj:`InstanceData`]): each contains
                ``labels`` and ``masks``.
            batch_img_metas (list[dict]): List of image meta information.

        Returns:
            tuple[Tensor]: Loss components for outputs from a single \
                decoder layer.
        """
        num_imgs = cls_scores.size(0)
        cls_scores_list = [cls_scores[i] for i in range(num_imgs)]
        mask_preds_list = [mask_preds[i] for i in range(num_imgs)]
        (labels_list, label_weights_list, mask_targets_list, mask_weights_list,
         avg_factor) = self.get_targets(cls_scores_list, mask_preds_list,
                                        batch_gt_instances, batch_img_metas)
        # shape (batch_size, num_queries)
        labels = torch.stack(labels_list, dim=0)
        # shape (batch_size, num_queries)
        label_weights = torch.stack(label_weights_list, dim=0)
        # shape (num_total_gts, num_frames, h, w)
        mask_targets = torch.cat(mask_targets_list, dim=0)
        # shape (batch_size, num_queries)
        mask_weights = torch.stack(mask_weights_list, dim=0)

        # classification loss
        # shape (batch_size * num_queries, )
        cls_scores = cls_scores.flatten(0, 1)
        labels = labels.flatten(0, 1)
        label_weights = label_weights.flatten(0, 1)

        class_weight = cls_scores.new_tensor(self.class_weight)
        loss_cls = self.loss_cls(
            cls_scores,
            labels,
            label_weights,
            avg_factor=class_weight[labels].sum())

        num_total_masks = reduce_mean(cls_scores.new_tensor([avg_factor]))
        num_total_masks = max(num_total_masks, 1)

        # extract positive ones
        # shape (batch_size, num_queries, num_frames, h, w)
        # -> (num_total_gts, num_frames, h, w)
        mask_preds = mask_preds[mask_weights > 0]

        if mask_targets.shape[0] == 0:
            # zero match
            loss_dice = mask_preds.sum()
            loss_mask = mask_preds.sum()
            return loss_cls, loss_mask, loss_dice
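        # Mask losses are evaluated on a subset of points sampled with a bias
        # towards uncertain (near-boundary) locations rather than on the full
        # masks, which keeps memory usage bounded.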
        with torch.no_grad():
            points_coords = get_uncertain_point_coords_with_randomness(
                mask_preds.flatten(0, 1).unsqueeze(1), None, self.num_points,
                self.oversample_ratio, self.importance_sample_ratio)
            # shape (num_total_gts * num_frames, h, w) ->
            # (num_total_gts * num_frames, num_points)
            mask_point_targets = point_sample(
                mask_targets.flatten(0, 1).unsqueeze(1).float(),
                points_coords).squeeze(1)
        # shape (num_total_gts * num_frames, num_points)
        mask_point_preds = point_sample(
            mask_preds.flatten(0, 1).unsqueeze(1), points_coords).squeeze(1)

        # dice loss
        loss_dice = self.loss_dice(
            mask_point_preds, mask_point_targets, avg_factor=num_total_masks)

        # mask loss
        # shape (num_total_gts * num_frames, num_points) ->
        # (num_total_gts * num_frames * num_points, )
        mask_point_preds = mask_point_preds.reshape(-1)
        mask_point_targets = mask_point_targets.reshape(-1)
        loss_mask = self.loss_mask(
            mask_point_preds,
            mask_point_targets,
            avg_factor=num_total_masks * self.num_points / self.num_frames)

        return loss_cls, loss_mask, loss_dice

    def _forward_head(
            self, decoder_out: Tensor, mask_feature: Tensor,
            attn_mask_target_size: Tuple[int,
                                         int]) -> Tuple[Tensor, Tensor, Tensor]:
        """Forward for head part which is called after every decoder layer.

        Args:
            decoder_out (Tensor): in shape (batch_size, num_queries, c).
            mask_feature (Tensor): in shape (batch_size, t, c, h, w).
            attn_mask_target_size (tuple[int, int]): target attention
                mask size.

        Returns:
            tuple: A tuple containing three elements.

                - cls_pred (Tensor): Classification scores in shape \
                    (batch_size, num_queries, cls_out_channels). \
                    Note `cls_out_channels` should include background.
                - mask_pred (Tensor): Mask scores in shape \
                    (batch_size, num_queries, t, h, w).
                - attn_mask (Tensor): Attention mask in shape \
                    (batch_size * num_heads, num_queries, t * h * w).
        """
        decoder_out = self.transformer_decoder.post_norm(decoder_out)
        cls_pred = self.cls_embed(decoder_out)
        mask_embed = self.mask_embed(decoder_out)
        # shape (batch_size, num_queries, t, h, w)
        mask_pred = torch.einsum('bqc,btchw->bqthw', mask_embed, mask_feature)
        b, q, t, _, _ = mask_pred.shape

        attn_mask = F.interpolate(
            mask_pred.flatten(0, 1),
            attn_mask_target_size,
            mode='bilinear',
            align_corners=False).view(b, q, t, attn_mask_target_size[0],
                                      attn_mask_target_size[1])
        # shape (batch_size, num_queries, t, h, w) ->
        # (batch_size, num_queries, t*h*w) ->
        # (batch_size, num_head, num_queries, t*h*w) ->
        # (batch_size*num_head, num_queries, t*h*w)
        attn_mask = attn_mask.flatten(2).unsqueeze(1).repeat(
            (1, self.num_heads, 1, 1)).flatten(0, 1)
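        # Masked attention: locations whose predicted foreground probability
        # falls below 0.5 are masked out, so each query only cross-attends to
        # its own predicted foreground region in the next decoder layer.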
        attn_mask = attn_mask.sigmoid() < 0.5
        attn_mask = attn_mask.detach()

        return cls_pred, mask_pred, attn_mask

    def forward(
        self, x: List[Tensor], data_samples: TrackDataSample
    ) -> Tuple[List[Tensor], List[Tensor]]:
        """Forward function.

        Args:
            x (list[Tensor]): Multi scale features from the
                upstream network, each is a 4D-tensor.
            data_samples (List[:obj:`TrackDataSample`]): The Data
                Samples. It usually includes information such as `gt_instance`.

        Returns:
            tuple[list[Tensor]]: A tuple containing two elements.

                - cls_pred_list (list[Tensor]): Classification logits \
                    for each decoder layer. Each is a 3D-tensor with shape \
                    (batch_size, num_queries, cls_out_channels). \
                    Note `cls_out_channels` should include background.
                - mask_pred_list (list[Tensor]): Mask logits for each \
                    decoder layer. Each with shape (batch_size, num_queries, \
                    t, h, w).
        """
        mask_features, multi_scale_memorys = self.pixel_decoder(x)
        bt, c_m, h_m, w_m = mask_features.shape
        batch_size = bt // self.num_frames if self.training else 1
        t = bt // batch_size
        mask_features = mask_features.view(batch_size, t, c_m, h_m, w_m)
        # multi_scale_memorys (from low resolution to high resolution)
        decoder_inputs = []
        decoder_positional_encodings = []
        for i in range(self.num_transformer_feat_level):
            decoder_input = self.decoder_input_projs[i](multi_scale_memorys[i])
            decoder_input = decoder_input.flatten(2)
            level_embed = self.level_embed.weight[i][None, :, None]
            decoder_input = decoder_input + level_embed
            _, c, hw = decoder_input.shape
            # shape (batch_size*t, c, h, w) ->
            # (batch_size, t, c, hw) ->
            # (batch_size, t*h*w, c)
            decoder_input = decoder_input.view(batch_size, t, c,
                                               hw).permute(0, 1, 3,
                                                           2).flatten(1, 2)
            # padding mask of shape (batch_size, t, h, w); nothing is padded
            mask = decoder_input.new_zeros(
                (batch_size, t) + multi_scale_memorys[i].shape[-2:],
                dtype=torch.bool)
            decoder_positional_encoding = self.decoder_positional_encoding(
                mask)
            decoder_positional_encoding = decoder_positional_encoding.flatten(
                3).permute(0, 1, 3, 2).flatten(1, 2)
            decoder_inputs.append(decoder_input)
            decoder_positional_encodings.append(decoder_positional_encoding)
        # shape (num_queries, c) -> (batch_size, num_queries, c)
        query_feat = self.query_feat.weight.unsqueeze(0).repeat(
            (batch_size, 1, 1))
        query_embed = self.query_embed.weight.unsqueeze(0).repeat(
            (batch_size, 1, 1))

        cls_pred_list = []
        mask_pred_list = []
        cls_pred, mask_pred, attn_mask = self._forward_head(
            query_feat, mask_features, multi_scale_memorys[0].shape[-2:])
        cls_pred_list.append(cls_pred)
        mask_pred_list.append(mask_pred)
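        # Decoder layers attend to the multi-scale features in a round-robin
        # fashion (low to high resolution); class scores, mask logits and the
        # attention mask are re-predicted after every layer.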
        for i in range(self.num_transformer_decoder_layers):
            level_idx = i % self.num_transformer_feat_level
            # if a mask is all True(all background), then set it all False.
            attn_mask[torch.where(
                attn_mask.sum(-1) == attn_mask.shape[-1])] = False

            # cross_attn + self_attn
            layer = self.transformer_decoder.layers[i]
            query_feat = layer(
                query=query_feat,
                key=decoder_inputs[level_idx],
                value=decoder_inputs[level_idx],
                query_pos=query_embed,
                key_pos=decoder_positional_encodings[level_idx],
                cross_attn_mask=attn_mask,
                query_key_padding_mask=None,
                # here we do not apply masking on padded region
                key_padding_mask=None)
            cls_pred, mask_pred, attn_mask = self._forward_head(
                query_feat, mask_features, multi_scale_memorys[
                    (i + 1) % self.num_transformer_feat_level].shape[-2:])

            cls_pred_list.append(cls_pred)
            mask_pred_list.append(mask_pred)

        return cls_pred_list, mask_pred_list

    def loss(
        self,
        x: Tuple[Tensor],
        data_samples: TrackSampleList,
    ) -> Dict[str, Tensor]:
        """Perform forward propagation and loss calculation of the track head
        on the features of the upstream network.

        Args:
            x (tuple[Tensor]): Multi-level features from the upstream
                network, each is a 4D-tensor.
            data_samples (List[:obj:`TrackDataSample`]): The Data
                Samples. It usually includes information such as `gt_instance`.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        batch_img_metas = []
        batch_gt_instances = []
        for data_sample in data_samples:
            video_img_metas = defaultdict(list)
            for image_idx in range(len(data_sample)):
                batch_gt_instances.append(data_sample[image_idx].gt_instances)
                for key, value in data_sample[image_idx].metainfo.items():
                    video_img_metas[key].append(value)
            batch_img_metas.append(video_img_metas)

        # forward
        all_cls_scores, all_mask_preds = self(x, data_samples)

        # preprocess ground truth
        batch_gt_instances = self.preprocess_gt(batch_gt_instances)
        # loss
        losses = self.loss_by_feat(all_cls_scores, all_mask_preds,
                                   batch_gt_instances, batch_img_metas)

        return losses

    def predict(self,
                x: Tuple[Tensor],
                data_samples: TrackDataSample,
                rescale: bool = True) -> InstanceList:
        """Test without augmentation.

        Args:
            x (tuple[Tensor]): Multi-level features from the
                upstream network, each is a 4D-tensor.
            data_samples (List[:obj:`TrackDataSample`]): The Data
                Samples. It usually includes information such as `gt_instance`.
            rescale (bool, Optional): If False, then returned bboxes and masks
                will fit the scale of img, otherwise, returned bboxes and masks
                will fit the scale of original image shape. Defaults to True.

        Returns:
            list[obj:`InstanceData`]: each contains the following keys

                - labels (Tensor): Prediction class indices\
                    for an image, with shape (n, ), n is the sum of\
                    number of stuff type and number of instance in an image.
                - masks (Tensor): Prediction mask for a\
                    image, with shape (n, t, h, w).
        """
        batch_img_metas = [
            data_samples[img_idx].metainfo
            for img_idx in range(len(data_samples))
        ]
        all_cls_scores, all_mask_preds = self(x, data_samples)
        mask_cls_results = all_cls_scores[-1]
        mask_pred_results = all_mask_preds[-1]

        mask_cls_results = mask_cls_results[0]
        # upsample masks
        img_shape = batch_img_metas[0]['batch_input_shape']
        mask_pred_results = F.interpolate(
            mask_pred_results[0],
            size=(img_shape[0], img_shape[1]),
            mode='bilinear',
            align_corners=False)

        results = self.predict_by_feat(mask_cls_results, mask_pred_results,
                                       batch_img_metas)
        return results

    def predict_by_feat(self,
                        mask_cls_results: List[Tensor],
                        mask_pred_results: List[Tensor],
                        batch_img_metas: List[dict],
                        rescale: bool = True) -> InstanceList:
        """Get top-10 predictions.

        Args:
            mask_cls_results (Tensor): Mask classification logits,\
                shape (num_queries, cls_out_channels).
                Note `cls_out_channels` should include background.
            mask_pred_results (Tensor): Mask logits, shape \
                (num_queries, t, h, w).
            batch_img_metas (list[dict]): List of image meta information.
            rescale (bool, Optional): If False, then returned bboxes and masks
                will fit the scale of img, otherwise, returned bboxes and masks
                will fit the scale of original image shape. Defaults to True.

        Returns:
            list[obj:`InstanceData`]: each contains the following keys

                - labels (Tensor): Prediction class indices\
                    for an image, with shape (n, ), n is the sum of\
                    number of stuff type and number of instance in an image.
                - masks (Tensor): Prediction mask for a\
                    image, with shape (n, t, h, w).
        """
        results = []
        if len(mask_cls_results) > 0:
            scores = F.softmax(mask_cls_results, dim=-1)[:, :-1]
            labels = torch.arange(self.num_classes).unsqueeze(0).repeat(
                self.num_queries, 1).flatten(0, 1).to(scores.device)
            # keep top-10 predictions
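            # Scores are flattened over all (query, class) pairs before the
            # top-k selection; integer division by `num_classes` then maps
            # each selected flat index back to its query so that the matching
            # mask prediction can be gathered.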
            scores_per_image, topk_indices = scores.flatten(0, 1).topk(
                10, sorted=False)
            labels_per_image = labels[topk_indices]
            topk_indices = topk_indices // self.num_classes
            mask_pred_results = mask_pred_results[topk_indices]

            img_shape = batch_img_metas[0]['img_shape']
            mask_pred_results = \
                mask_pred_results[:, :, :img_shape[0], :img_shape[1]]
            if rescale:
                # return result in original resolution
                ori_height, ori_width = batch_img_metas[0]['ori_shape'][:2]
                mask_pred_results = F.interpolate(
                    mask_pred_results,
                    size=(ori_height, ori_width),
                    mode='bilinear',
                    align_corners=False)

            masks = mask_pred_results > 0.

            # format top-10 predictions
            for img_idx in range(len(batch_img_metas)):
                pred_track_instances = InstanceData()
                pred_track_instances.masks = masks[:, img_idx]
                pred_track_instances.bboxes = mask2bbox(masks[:, img_idx])
                pred_track_instances.labels = labels_per_image
                pred_track_instances.scores = scores_per_image
                pred_track_instances.instances_id = torch.arange(10)
                results.append(pred_track_instances)

        return results
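

# ---------------------------------------------------------------------------
# Illustrative note (not part of the original module): the per-layer mask
# prediction in `_forward_head` is a dot product between query embeddings and
# the per-frame mask features, e.g. with batch_size=1, num_queries=100,
# num_frames=2 and 256-dim features:
#
#     mask_embed:   (1, 100, 256)
#     mask_feature: (1, 2, 256, h, w)
#     mask_pred = torch.einsum('bqc,btchw->bqthw', mask_embed, mask_feature)
#     # -> (1, 100, 2, h, w)
#
# During training, `loss()` consumes these per-layer predictions; at test
# time, `predict()` keeps only the last decoder layer and `predict_by_feat()`
# reduces it to the top-10 (query, class) pairs per video.
# ---------------------------------------------------------------------------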