import copy
from typing import Sequence

import torch
from mmengine.structures import InstanceData, PixelData
from torch import nn
from torch.nn import functional as F

from mmdet.evaluation.functional import INSTANCE_OFFSET
from mmdet.registry import MODELS
from mmdet.utils import ConfigType, OptConfigType
from .utils import (is_lower_torch_version, retry_if_cuda_oom,
                    sem_seg_postprocess)

@MODELS.register_module()
class XDecoderUnifiedhead(nn.Module):
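    """Unified prediction head of X-Decoder, shared by all supported tasks.

    A pixel decoder and a transformer decoder are built from their configs,
    and pre-/post-processing is dispatched on ``task``.

    Args:
        in_channels (int): Number of channels of the input features.
        pixel_decoder (:obj:`ConfigDict` or dict): Config of the pixel
            decoder.
        transformer_decoder (:obj:`ConfigDict` or dict): Config of the
            transformer decoder.
        task (str): Task name. Supported values are 'semseg', 'instance',
            'panoptic', 'ref-seg', 'ref-caption', 'caption' and
            'retrieval'. Defaults to 'semseg'.
        test_cfg (:obj:`ConfigDict` or dict, optional): Test-time config
            with keys such as ``mask_thr``, ``overlap_thr``,
            ``max_per_img``, ``use_thr_for_mc`` and ``ignore_index``.
            Defaults to None.
    """
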
    def __init__(self,
                 in_channels: int,
                 pixel_decoder: ConfigType,
                 transformer_decoder: ConfigType,
                 task: str = 'semseg',
                 test_cfg: OptConfigType = None):
        super().__init__()
        self.task = task
        self.test_cfg = test_cfg

        pixel_decoder_ = copy.deepcopy(pixel_decoder)
        pixel_decoder_.update(in_channels=in_channels)
        self.pixel_decoder = MODELS.build(pixel_decoder_)

        transformer_decoder_ = copy.deepcopy(transformer_decoder)
        transformer_decoder_.update(task=task)
        self.predictor = MODELS.build(transformer_decoder_)

        self.return_inter_mask = False
        if self.task == 'ref-caption':
            # ref-caption = ref-seg + caption, so the intermediate
            # referring masks must also be returned.
            self.return_inter_mask = True

        # Cache the last text prompts and the corresponding language
        # embeddings so unchanged prompts are not re-encoded.
        self._all_text_prompts = None
        self._extra = None
        # TODO: Hacky flag for the retrieval task, which must bypass
        # the text-embedding cache.
        self._force_not_use_cache = False

    def pre_process(self, batch_data_samples, device):
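        """Prepare the ``extra`` inputs of the transformer decoder.

        For text-conditioned tasks this encodes the text prompts of the
        first data sample (batching is not supported yet) and caches the
        result until the prompts change. For captioning it only prepares
        the start-of-text token.

        Returns:
            tuple: ``(extra, all_text_prompts, num_thing_class)``. The
            last two items are ``None`` for the caption task.
        """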
        extra = {}
        if self.task != 'caption':
            # These tasks are conditioned on text prompts.
            all_text_prompts = []
            num_thing_class = 0
            for data_samples in batch_data_samples:
                if isinstance(data_samples.text, str):
                    text = data_samples.text.split('.')
                elif isinstance(data_samples.text, Sequence):
                    text = data_samples.text
                else:
                    raise TypeError(
                        'Type of data_samples.text must be Sequence or str')
                text = list(filter(lambda x: len(x) > 0, text))
                all_text_prompts.append(text)
                num_thing_class = len(text)

                # For panoptic segmentation, stuff classes are appended
                # after the thing classes.
                if 'stuff_text' in data_samples:
                    if isinstance(data_samples.stuff_text, str):
                        text = data_samples.stuff_text.split('.')
                    elif isinstance(data_samples.stuff_text, Sequence):
                        text = data_samples.stuff_text
                    else:
                        raise TypeError('Type of data_samples.stuff_text '
                                        'must be Sequence or str')
                    text = list(filter(lambda x: len(x) > 0, text))
                    all_text_prompts[-1].extend(text)

            # TODO: support batch
            all_text_prompts = all_text_prompts[0]

            if all_text_prompts != self._all_text_prompts \
                    or self._force_not_use_cache:
                # The prompts changed (or caching is disabled), so the
                # text embeddings must be recomputed.
                self._all_text_prompts = all_text_prompts
                if self.task in ['semseg', 'instance', 'panoptic']:
                    self.predictor.lang_encoder.get_mean_embeds(
                        all_text_prompts + ['background'])
                elif self.task == 'ref-seg':
                    token_info = self.predictor.lang_encoder.get_text_embeds(
                        all_text_prompts, norm=False)
                    token_emb = token_info['token_emb']
                    tokens = token_info['tokens']
                    query_emb = token_emb[tokens['attention_mask'].bool()]
                    extra['grounding_tokens'] = query_emb[:, None]
                    extra['class_emb'] = token_info['class_emb']
                elif self.task == 'retrieval':
                    token_info = self.predictor.lang_encoder.get_text_embeds(
                        all_text_prompts, norm=True)
                    extra['class_emb'] = token_info['class_emb']
                self._extra = extra
                return extra, all_text_prompts, num_thing_class
            else:
                # Reuse the cached embeddings for identical prompts.
                return self._extra, all_text_prompts, num_thing_class
        else:
            # Captioning only needs the start-of-text token.
            if not hasattr(self, 'start_token'):
                self.start_token = self.predictor.lang_encoder. \
                    get_sot_token(device=device)
            extra['start_token'] = self.start_token
            return extra, None, None

    def predict(self, features, batch_data_samples):
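        """Predict results from backbone features.

        Runs the pixel decoder, encodes the text prompts, forwards the
        transformer decoder and post-processes for the current task.
        """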
        # Per-pixel and multi-scale features from the pixel decoder.
        mask_features, multi_scale_features = self.pixel_decoder(features)
        # Text prompts (or the start token for captioning).
        extra, all_text_prompts, num_thing_class = self.pre_process(
            batch_data_samples, mask_features.device)
        # Transformer decoder forward.
        predictions = self.predictor(
            multi_scale_features, mask_features, extra=extra)
        # Task-specific post-processing.
        return self.post_process(predictions, batch_data_samples,
                                 all_text_prompts, num_thing_class)

    def post_process(self, predictions, batch_data_samples, all_text_prompts,
                     num_thing_class):
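        """Convert raw decoder outputs into task-specific predictions and
        attach them to ``batch_data_samples``, which is then returned."""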
        batch_img_metas = [
            data_samples.metainfo for data_samples in batch_data_samples
        ]
        batch_input_shape = batch_data_samples[0].metainfo['batch_input_shape']

        if self.task == 'caption':
            for text, data_samples in zip(predictions['pred_caption'],
                                          batch_data_samples):
                data_samples.pred_caption = text

            # For ref-caption, also refine the intermediate referring
            # masks produced by the ref-seg stage.
            if 'pred_instances' in batch_data_samples[0]:
                for img_metas, data_samples in zip(batch_img_metas,
                                                   batch_data_samples):
                    original_caption = data_samples.text.split('.')
                    text_prompts = list(
                        filter(lambda x: len(x) > 0, original_caption))

                    height = img_metas['ori_shape'][0]
                    width = img_metas['ori_shape'][1]
                    image_size = img_metas['grounding_img_shape'][:2]

                    pred = data_samples.pred_instances
                    mask_pred_result = pred.masks.float()
                    mask_cls_result = pred.scores.float()

                    mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)(
                        mask_pred_result, image_size, height, width)
                    pred_instances = retry_if_cuda_oom(
                        self._instance_inference)(mask_cls_result,
                                                  mask_pred_result,
                                                  text_prompts)
                    data_samples.pred_instances = pred_instances
        elif self.task in ['semseg', 'instance', 'panoptic']:
            mask_pred_results = predictions['pred_masks']
            mask_cls_results = predictions['pred_logits']

            # ``antialias`` is only available in newer torch versions.
            if is_lower_torch_version():
                mask_pred_results = F.interpolate(
                    mask_pred_results,
                    size=(batch_input_shape[-2], batch_input_shape[-1]),
                    mode='bicubic',
                    align_corners=False)
            else:
                mask_pred_results = F.interpolate(
                    mask_pred_results,
                    size=(batch_input_shape[-2], batch_input_shape[-1]),
                    mode='bicubic',
                    align_corners=False,
                    antialias=True)

            # Post-process each image in the batch.
            for mask_cls_result, mask_pred_result, img_metas, \
                    data_samples in zip(mask_cls_results, mask_pred_results,
                                        batch_img_metas, batch_data_samples):
                height = img_metas['ori_shape'][0]
                width = img_metas['ori_shape'][1]
                image_size = img_metas['img_shape'][:2]

                # Remove padding and resize masks back to the original
                # image resolution.
                mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)(
                    mask_pred_result, image_size, height, width)
                mask_cls_result = mask_cls_result.to(mask_pred_result)

                if self.task == 'semseg':
                    pred_sem_seg = retry_if_cuda_oom(self._semantic_inference)(
                        mask_cls_result, mask_pred_result, all_text_prompts)
                    data_samples.pred_sem_seg = pred_sem_seg
                elif self.task == 'instance':
                    pred_instances = retry_if_cuda_oom(
                        self._instance_inference)(mask_cls_result,
                                                  mask_pred_result,
                                                  all_text_prompts)
                    data_samples.pred_instances = pred_instances
                elif self.task == 'panoptic':
                    pred_panoptic_seg = retry_if_cuda_oom(
                        self._panoptic_inference)(mask_cls_result,
                                                  mask_pred_result,
                                                  all_text_prompts,
                                                  num_thing_class)
                    data_samples.pred_panoptic_seg = pred_panoptic_seg
        elif self.task == 'ref-seg':
            mask_pred_results = predictions['pred_masks']
            mask_cls_results = predictions['pred_logits']
            results_ = zip(mask_pred_results, mask_cls_results,
                           batch_img_metas, batch_data_samples)
            for mask_pred_result, mask_cls_result, \
                    img_metas, data_samples in results_:
                if is_lower_torch_version():
                    mask_pred_result = F.interpolate(
                        mask_pred_result[None],
                        size=(batch_input_shape[-2], batch_input_shape[-1]),
                        mode='bicubic',
                        align_corners=False)[0]
                else:
                    mask_pred_result = F.interpolate(
                        mask_pred_result[None],
                        size=(batch_input_shape[-2], batch_input_shape[-1]),
                        mode='bicubic',
                        align_corners=False,
                        antialias=True)[0]

                if self.return_inter_mask:
                    # For ref-caption, return the intermediate masks
                    # without rescaling to the original image size.
                    mask = mask_pred_result > 0
                    pred_instances = InstanceData()
                    pred_instances.masks = mask
                    pred_instances.scores = mask_cls_result
                    data_samples.pred_instances = pred_instances
                    continue

                height = img_metas['ori_shape'][0]
                width = img_metas['ori_shape'][1]
                image_size = img_metas['img_shape'][:2]
                mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)(
                    mask_pred_result, image_size, height, width)
                pred_instances = retry_if_cuda_oom(self._instance_inference)(
                    mask_cls_result, mask_pred_result, all_text_prompts)
                data_samples.pred_instances = pred_instances
        elif self.task == 'retrieval':
            # Image-text similarity logits serve as the retrieval score.
            batch_data_samples[0].pred_score = predictions['pred_logits']
        return batch_data_samples

    def _instance_inference(self, mask_cls, mask_pred, text_prompts):
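        """Instance segmentation inference: select top-scoring (query,
        class) pairs and derive per-instance masks, scores and labels."""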
        num_class = len(text_prompts)

        if self.task in ['ref-seg', 'caption']:
            # One query per referring expression: keep all of them.
            scores = F.softmax(mask_cls, dim=-1)
            scores_per_image = scores.max(dim=-1)[0]
            labels_per_image = torch.arange(num_class)
        else:
            # Drop the background logit, then take the top-k
            # (query, class) pairs over all queries.
            scores = F.softmax(mask_cls, dim=-1)[:, :-1]
            labels = torch.arange(
                num_class,
                device=scores.device).unsqueeze(0).repeat(scores.shape[0],
                                                          1).flatten(0, 1)
            scores_per_image, topk_indices = scores.flatten(0, 1).topk(
                self.test_cfg.get('max_per_img', 100), sorted=False)
            labels_per_image = labels[topk_indices]
            topk_indices = topk_indices // num_class
            mask_pred = mask_pred[topk_indices]

        result = InstanceData()
        mask_pred = mask_pred.sigmoid()
        result.masks = (mask_pred > self.test_cfg.mask_thr).float()

        # Rescore with the average foreground mask probability.
        mask_scores_per_image = (mask_pred.flatten(1) *
                                 result.masks.flatten(1)).sum(1) / (
                                     result.masks.flatten(1).sum(1) + 1e-6)
        result.scores = scores_per_image * mask_scores_per_image
        result.labels = labels_per_image
        result.label_names = [
            text_prompts[label] for label in labels_per_image
        ]
        result.bboxes = result.scores.new_zeros(len(result.scores), 4)
        return result

    def _semantic_inference(self, mask_cls, mask_pred, text_prompts):
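        """Semantic segmentation inference: aggregate query masks weighted
        by class probabilities into a per-pixel label map."""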
        mask_cls = F.softmax(mask_cls, dim=-1)[..., :-1]
        mask_pred = mask_pred.sigmoid()
        sem_seg = torch.einsum('qc,qhw->chw', mask_cls, mask_pred)

        if sem_seg.shape[0] == 1:
            # 0 is foreground, ignore_index is background
            sem_seg = (sem_seg.squeeze(0) <= self.test_cfg.mask_thr).int()
            sem_seg[sem_seg == 1] = self.test_cfg.get('ignore_index', 255)
        else:
            # 0 is foreground, ignore_index is background
            if self.test_cfg.use_thr_for_mc:
                foreground_flag = sem_seg > self.test_cfg.mask_thr
                sem_seg = sem_seg.max(0)[1]
                sem_seg[foreground_flag.sum(0) == 0] = self.test_cfg.get(
                    'ignore_index', 255)
            else:
                sem_seg = sem_seg.max(0)[1]
        pred_sem_seg = PixelData(
            sem_seg=sem_seg[None],
            metainfo={
                'label_names': text_prompts,
                'ignore_index': self.test_cfg.get('ignore_index', 255)
            })
        return pred_sem_seg

    def _panoptic_inference(self, mask_cls, mask_pred, all_text_prompts,
                            num_thing_class):
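        """Panoptic segmentation inference: assign each pixel to its
        highest-probability query, keeping thing instances separate and
        merging stuff regions by class."""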
        scores, labels = F.softmax(mask_cls, dim=-1).max(-1)
        mask_pred = mask_pred.sigmoid()

        # Keep queries that are neither background nor low-scoring.
        keep = labels.ne(len(all_text_prompts)) & (
            scores > self.test_cfg.mask_thr)
        cur_scores = scores[keep]
        cur_classes = labels[keep]
        cur_masks = mask_pred[keep]
        cur_prob_masks = cur_scores.view(-1, 1, 1) * cur_masks

        h, w = cur_masks.shape[-2:]
        panoptic_seg = torch.full((h, w),
                                  self.test_cfg.get('ignore_index', 255),
                                  dtype=torch.int32,
                                  device=cur_masks.device)
        instance_id = 1

        if cur_masks.shape[0] > 0:
            # Assign each pixel to the query with the highest
            # score-weighted mask probability.
            cur_mask_ids = cur_prob_masks.argmax(0)
            for k in range(cur_classes.shape[0]):
                pred_class = cur_classes[k].item()
                # Thing classes come first in the prompt list.
                isthing = int(pred_class) < num_thing_class
                mask_area = (cur_mask_ids == k).sum().item()
                original_area = (cur_masks[k] >= 0.5).sum().item()
                mask = (cur_mask_ids == k) & (cur_masks[k] >= 0.5)

                if mask_area > 0 and original_area > 0 \
                        and mask.sum().item() > 0:
                    # Skip heavily-occluded fragments.
                    if mask_area / original_area < self.test_cfg.overlap_thr:
                        continue
                    if not isthing:
                        # Merge stuff regions of the same class.
                        panoptic_seg[mask] = int(pred_class)
                    else:
                        panoptic_seg[mask] = int(
                            pred_class) + instance_id * INSTANCE_OFFSET
                        instance_id += 1

        panoptic_seg = PixelData(
            sem_seg=panoptic_seg[None],
            metainfo={
                'label_names': all_text_prompts,
                'ignore_index': self.test_cfg.get('ignore_index', 255)
            })
        return panoptic_seg
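

# A minimal usage sketch (assumptions: the nested decoder configs and the
# channel count below are placeholders, not the shipped X-Decoder configs).
# The head is normally built by the detector from its config via the
# registry:
#
#     head = MODELS.build(
#         dict(
#             type='XDecoderUnifiedhead',
#             in_channels=512,  # placeholder channel count
#             pixel_decoder=dict(...),  # elided pixel decoder config
#             transformer_decoder=dict(...),  # elided decoder config
#             task='semseg',
#             test_cfg=dict(mask_thr=0.5, ignore_index=255)))
#     batch_data_samples = ...  # DetDataSample list carrying `text` prompts
#     results = head.predict(backbone_features, batch_data_samples)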