import copy
from typing import Sequence

import torch
from mmengine.structures import InstanceData, PixelData
from torch import nn
from torch.nn import functional as F

from mmdet.evaluation.functional import INSTANCE_OFFSET
from mmdet.registry import MODELS

from .utils import (is_lower_torch_version, retry_if_cuda_oom,
                    sem_seg_postprocess)


@MODELS.register_module()
class XDecoderUnifiedhead(nn.Module):
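    """Unified decoding head of X-Decoder for the prediction path.

    A pixel decoder and a transformer decoder are built from their configs,
    and pre-/post-processing is dispatched according to ``task``, which is
    one of 'semseg', 'instance', 'panoptic', 'ref-seg', 'ref-caption',
    'caption' or 'retrieval'.
    """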

    def __init__(self,
                 in_channels: int,
                 pixel_decoder: nn.Module,
                 transformer_decoder: nn.Module,
                 task: str = 'semseg',
                 test_cfg=None):
        super().__init__()
        self.task = task
        self.test_cfg = test_cfg

        pixel_decoder_ = copy.deepcopy(pixel_decoder)
        pixel_decoder_.update(in_channels=in_channels)
        self.pixel_decoder = MODELS.build(pixel_decoder_)

        transformer_decoder_ = copy.deepcopy(transformer_decoder)
        transformer_decoder_.update(task=task)
        self.predictor = MODELS.build(transformer_decoder_)

        self.return_inter_mask = False
        if self.task == 'ref-caption':
            # ref-caption = ref-seg + caption,
            # so we need to return the intermediate mask
            self.return_inter_mask = True

        self._all_text_prompts = None
        self._extra = None
        # TODO: Very tricky, for retrieval task
        self._force_not_use_cache = False

    def pre_process(self, batch_data_samples, device):
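        """Build the task-specific ``extra`` inputs for the decoder.

        Text prompts (and, for panoptic tasks, stuff text) are parsed from
        the data samples and encoded by the language encoder. The encoded
        results are cached so that identical prompts are not re-encoded on
        every call. For the caption task only the start-of-text token is
        prepared.

        Returns:
            tuple: ``(extra, all_text_prompts, num_thing_class)``; the last
            two are ``None`` for the caption task.
        """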
        extra = {}
        if self.task != 'caption':
            # all tasks except caption carry text prompts
            all_text_prompts = []
            num_thing_class = 0
            for data_samples in batch_data_samples:
                if isinstance(data_samples.text, str):
                    text = data_samples.text.split('.')
                elif isinstance(data_samples.text, Sequence):
                    text = data_samples.text
                else:
                    raise TypeError(
                        'Type of data_sample.text must be sequence or str')
                text = list(filter(lambda x: len(x) > 0, text))
                all_text_prompts.append(text)
                num_thing_class = len(text)

                # for panoptic
                if 'stuff_text' in data_samples:
                    if isinstance(data_samples.stuff_text, str):
                        text = data_samples.stuff_text.split('.')
                    elif isinstance(data_samples.stuff_text, Sequence):
                        text = data_samples.stuff_text
                    else:
                        raise TypeError('Type of data_sample.stuff_text '
                                        'must be sequence or str')
                    text = list(filter(lambda x: len(x) > 0, text))
                    all_text_prompts[-1].extend(text)

            # TODO: support batch
            all_text_prompts = all_text_prompts[0]

            if all_text_prompts != self._all_text_prompts \
                    or self._force_not_use_cache:
                # avoid redundant computation
                self._all_text_prompts = all_text_prompts

                if self.task in ['semseg', 'instance', 'panoptic']:
                    self.predictor.lang_encoder.get_mean_embeds(
                        all_text_prompts + ['background'])
                elif self.task == 'ref-seg':
                    token_info = self.predictor.lang_encoder.get_text_embeds(
                        all_text_prompts, norm=False)
                    token_emb = token_info['token_emb']
                    tokens = token_info['tokens']
                    query_emb = token_emb[tokens['attention_mask'].bool()]
                    extra['grounding_tokens'] = query_emb[:, None]
                    extra['class_emb'] = token_info['class_emb']
                elif self.task == 'retrieval':
                    token_info = self.predictor.lang_encoder.get_text_embeds(
                        all_text_prompts, norm=True)
                    extra['class_emb'] = token_info['class_emb']

                self._extra = extra
                return extra, all_text_prompts, num_thing_class
            else:
                return self._extra, all_text_prompts, num_thing_class
        else:
            if not hasattr(self, 'start_token'):
                self.start_token = \
                    self.predictor.lang_encoder.get_sot_token(device=device)
            extra['start_token'] = self.start_token
            return extra, None, None

    def predict(self, features, batch_data_samples):
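        """Predict results from backbone features.

        The pixel decoder produces mask features and multi-scale features,
        the transformer decoder predicts masks/logits (or captions), and
        ``post_process`` writes the final results into the data samples.
        """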
        # multi scale feature
        mask_features, multi_scale_features = self.pixel_decoder(features)
        # pre process
        extra, all_text_prompts, num_thing_class = self.pre_process(
            batch_data_samples, mask_features.device)
        # transformer decoder forward
        predictions = self.predictor(
            multi_scale_features, mask_features, extra=extra)
        # post process
        return self.post_process(predictions, batch_data_samples,
                                 all_text_prompts, num_thing_class)

    def post_process(self, predictions, batch_data_samples, all_text_prompts,
                     num_thing_class):
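        """Convert raw decoder outputs into results on the data samples.

        Predicted masks are resized to the padded batch input shape,
        rescaled back to the original image size, and passed to the
        task-specific inference routine (semantic, instance, panoptic,
        referring segmentation, caption grounding or retrieval).
        """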
        batch_img_metas = [
            data_samples.metainfo for data_samples in batch_data_samples
        ]
        batch_input_shape = batch_data_samples[0].metainfo['batch_input_shape']

        if self.task == 'caption':
            for text, data_samples in zip(predictions['pred_caption'],
                                          batch_data_samples):
                data_samples.pred_caption = text

            if 'pred_instances' in batch_data_samples[0]:
                for img_metas, data_samples in zip(batch_img_metas,
                                                   batch_data_samples):
                    original_caption = data_samples.text.split('.')
                    text_prompts = list(
                        filter(lambda x: len(x) > 0, original_caption))

                    height = img_metas['ori_shape'][0]
                    width = img_metas['ori_shape'][1]
                    image_size = img_metas['grounding_img_shape'][:2]

                    mask_pred_result = data_samples.pred_instances.masks.float()
                    mask_cls_result = data_samples.pred_instances.scores.float()

                    mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)(
                        mask_pred_result, image_size, height, width)
                    pred_instances = retry_if_cuda_oom(
                        self._instance_inference)(mask_cls_result,
                                                  mask_pred_result,
                                                  text_prompts)
                    data_samples.pred_instances = pred_instances
        elif self.task in ['semseg', 'instance', 'panoptic']:
            mask_pred_results = predictions['pred_masks']
            mask_cls_results = predictions['pred_logits']

            if is_lower_torch_version():
                mask_pred_results = F.interpolate(
                    mask_pred_results,
                    size=(batch_input_shape[-2], batch_input_shape[-1]),
                    mode='bicubic',
                    align_corners=False)
            else:
                mask_pred_results = F.interpolate(
                    mask_pred_results,
                    size=(batch_input_shape[-2], batch_input_shape[-1]),
                    mode='bicubic',
                    align_corners=False,
                    antialias=True)

            # for batch
            for mask_cls_result, mask_pred_result, img_metas, \
                    data_samples in zip(mask_cls_results, mask_pred_results,
                                        batch_img_metas, batch_data_samples):
                height = img_metas['ori_shape'][0]
                width = img_metas['ori_shape'][1]
                image_size = img_metas['img_shape'][:2]

                mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)(
                    mask_pred_result, image_size, height, width)
                mask_cls_result = mask_cls_result.to(mask_pred_result)

                if self.task == 'semseg':
                    pred_sem_seg = retry_if_cuda_oom(self._semantic_inference)(
                        mask_cls_result, mask_pred_result, all_text_prompts)
                    data_samples.pred_sem_seg = pred_sem_seg
                elif self.task == 'instance':
                    pred_instances = retry_if_cuda_oom(
                        self._instance_inference)(mask_cls_result,
                                                  mask_pred_result,
                                                  all_text_prompts)
                    data_samples.pred_instances = pred_instances
                elif self.task == 'panoptic':
                    pred_panoptic_seg = retry_if_cuda_oom(
                        self._panoptic_inference)(mask_cls_result,
                                                  mask_pred_result,
                                                  all_text_prompts,
                                                  num_thing_class)
                    data_samples.pred_panoptic_seg = pred_panoptic_seg
        elif self.task == 'ref-seg':
            mask_pred_results = predictions['pred_masks']
            mask_cls_results = predictions['pred_logits']
            results_ = zip(mask_pred_results, mask_cls_results,
                           batch_img_metas, batch_data_samples)
            for mask_pred_result, mask_cls_result, \
                    img_metas, data_samples in results_:
                if is_lower_torch_version():
                    mask_pred_result = F.interpolate(
                        mask_pred_result[None],
                        size=(batch_input_shape[-2], batch_input_shape[-1]),
                        mode='bicubic',
                        align_corners=False)[0]
                else:
                    mask_pred_result = F.interpolate(
                        mask_pred_result[None],
                        size=(batch_input_shape[-2], batch_input_shape[-1]),
                        mode='bicubic',
                        align_corners=False,
                        antialias=True)[0]

                if self.return_inter_mask:
                    mask = mask_pred_result > 0
                    pred_instances = InstanceData()
                    pred_instances.masks = mask
                    pred_instances.scores = mask_cls_result
                    data_samples.pred_instances = pred_instances
                    continue

                height = img_metas['ori_shape'][0]
                width = img_metas['ori_shape'][1]
                image_size = img_metas['img_shape'][:2]

                mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)(
                    mask_pred_result, image_size, height, width)

                pred_instances = retry_if_cuda_oom(self._instance_inference)(
                    mask_cls_result, mask_pred_result, all_text_prompts)
                data_samples.pred_instances = pred_instances
        elif self.task == 'retrieval':
            batch_data_samples[0].pred_score = predictions['pred_logits']
        return batch_data_samples

    def _instance_inference(self, mask_cls, mask_pred, text_prompts):
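        """Instance segmentation inference for a single image.

        For 'ref-seg' and 'caption' every mask keeps its own label; otherwise
        the top-k (``max_per_img``) query-class pairs are selected. Masks are
        binarized with ``mask_thr`` and the classification score is
        multiplied by the mean in-mask probability.
        """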
        num_class = len(text_prompts)

        if self.task in ['ref-seg', 'caption']:
            scores = F.softmax(mask_cls, dim=-1)
            scores_per_image = scores.max(dim=-1)[0]
            labels_per_image = torch.arange(num_class)
        else:
            scores = F.softmax(mask_cls, dim=-1)[:, :-1]
            labels = torch.arange(
                num_class,
                device=scores.device).unsqueeze(0).repeat(scores.shape[0],
                                                          1).flatten(0, 1)
            scores_per_image, topk_indices = scores.flatten(0, 1).topk(
                self.test_cfg.get('max_per_img', 100), sorted=False)
            labels_per_image = labels[topk_indices]
            topk_indices = topk_indices // num_class
            mask_pred = mask_pred[topk_indices]

        result = InstanceData()
        mask_pred = mask_pred.sigmoid()
        result.masks = (mask_pred > self.test_cfg.mask_thr).float()

        # calculate average mask prob
        mask_scores_per_image = (mask_pred.flatten(1) *
                                 result.masks.flatten(1)).sum(1) / (
                                     result.masks.flatten(1).sum(1) + 1e-6)
        result.scores = scores_per_image * mask_scores_per_image
        result.labels = labels_per_image
        result.label_names = [
            text_prompts[label] for label in labels_per_image
        ]
        result.bboxes = result.scores.new_zeros(len(result.scores), 4)
        return result

    def _semantic_inference(self, mask_cls, mask_pred, text_prompts):
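        """Semantic segmentation inference for a single image.

        Class probabilities and mask probabilities are combined into a
        per-class score map. For a single class the map is thresholded with
        ``mask_thr``; otherwise the per-pixel argmax is taken, optionally
        marking pixels below the threshold for every class as
        ``ignore_index``.
        """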
        mask_cls = F.softmax(mask_cls, dim=-1)[..., :-1]
        mask_pred = mask_pred.sigmoid()
        sem_seg = torch.einsum('qc,qhw->chw', mask_cls, mask_pred)

        if sem_seg.shape[0] == 1:
            # 0 is foreground, ignore_index is background
            sem_seg = (sem_seg.squeeze(0) <= self.test_cfg.mask_thr).int()
            sem_seg[sem_seg == 1] = self.test_cfg.get('ignore_index', 255)
        else:
            # multi-class: take the per-pixel argmax; optionally mark pixels
            # below the threshold for every class as background (ignore_index)
            if self.test_cfg.use_thr_for_mc:
                foreground_flag = sem_seg > self.test_cfg.mask_thr
                sem_seg = sem_seg.max(0)[1]
                sem_seg[foreground_flag.sum(0) == 0] = self.test_cfg.get(
                    'ignore_index', 255)
            else:
                sem_seg = sem_seg.max(0)[1]

        pred_sem_seg = PixelData(
            sem_seg=sem_seg[None],
            metainfo={
                'label_names': text_prompts,
                'ignore_index': self.test_cfg.get('ignore_index', 255)
            })
        return pred_sem_seg

    def _panoptic_inference(self, mask_cls, mask_pred, all_text_prompts,
                            num_thing_class):
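        """Panoptic segmentation inference for a single image.

        Queries predicted as background or scoring below ``mask_thr`` are
        dropped, overlaps are resolved by a per-pixel argmax of
        score-weighted masks, and segments whose visible area falls below
        ``overlap_thr`` of the full mask are discarded. Stuff classes keep
        their class id; thing classes are offset per instance with
        ``INSTANCE_OFFSET``.
        """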
        scores, labels = F.softmax(mask_cls, dim=-1).max(-1)
        mask_pred = mask_pred.sigmoid()

        keep = labels.ne(len(all_text_prompts)) & (
            scores > self.test_cfg.mask_thr)
        cur_scores = scores[keep]
        cur_classes = labels[keep]
        cur_masks = mask_pred[keep]
        cur_prob_masks = cur_scores.view(-1, 1, 1) * cur_masks

        h, w = cur_masks.shape[-2:]
        panoptic_seg = torch.full((h, w),
                                  self.test_cfg.get('ignore_index', 255),
                                  dtype=torch.int32,
                                  device=cur_masks.device)

        instance_id = 1
        if cur_masks.shape[0] > 0:
            cur_mask_ids = cur_prob_masks.argmax(0)
            for k in range(cur_classes.shape[0]):
                pred_class = cur_classes[k].item()
                isthing = int(pred_class) < num_thing_class
                mask_area = (cur_mask_ids == k).sum().item()
                original_area = (cur_masks[k] >= 0.5).sum().item()
                mask = (cur_mask_ids == k) & (cur_masks[k] >= 0.5)

                if (mask_area > 0 and original_area > 0
                        and mask.sum().item() > 0):
                    if mask_area / original_area < self.test_cfg.overlap_thr:
                        continue
                    # merge stuff regions
                    if not isthing:
                        panoptic_seg[mask] = int(pred_class)
                    else:
                        panoptic_seg[mask] = int(
                            pred_class) + instance_id * INSTANCE_OFFSET
                        instance_id += 1

        panoptic_seg = PixelData(
            sem_seg=panoptic_seg[None],
            metainfo={
                'label_names': all_text_prompts,
                'ignore_index': self.test_cfg.get('ignore_index', 255)
            })
        return panoptic_seg