import copy
from typing import Sequence

import torch
from mmengine.structures import InstanceData, PixelData
from torch import nn
from torch.nn import functional as F

from mmdet.evaluation.functional import INSTANCE_OFFSET
from mmdet.registry import MODELS
from mmdet.utils import ConfigType, OptConfigType
from .utils import (is_lower_torch_version, retry_if_cuda_oom,
                    sem_seg_postprocess)


@MODELS.register_module()
class XDecoderUnifiedhead(nn.Module):
    """Unified X-Decoder head that routes a shared pixel decoder and
    transformer decoder to one of several tasks: semantic, instance or
    panoptic segmentation, referring segmentation, captioning,
    ref-caption and image-text retrieval."""

    def __init__(self,
                 in_channels: int,
                 pixel_decoder: ConfigType,
                 transformer_decoder: ConfigType,
                 task: str = 'semseg',
                 test_cfg: OptConfigType = None):
        super().__init__()
        self.task = task
        self.test_cfg = test_cfg

        pixel_decoder_ = copy.deepcopy(pixel_decoder)
        pixel_decoder_.update(in_channels=in_channels)
        self.pixel_decoder = MODELS.build(pixel_decoder_)

        transformer_decoder_ = copy.deepcopy(transformer_decoder)
        transformer_decoder_.update(task=task)
        self.predictor = MODELS.build(transformer_decoder_)

        self.return_inter_mask = False
        if self.task == 'ref-caption':
            # ref-caption = ref-seg + caption,
            # so we need to return the intermediate mask.
            self.return_inter_mask = True

        # Cache of the last text prompts and their language embeddings, so
        # repeated inference with identical prompts skips re-encoding.
        self._all_text_prompts = None
        self._extra = None
        # TODO: Very tricky; needed for the retrieval task.
        self._force_not_use_cache = False

    def pre_process(self, batch_data_samples, device):
        """Encode the text prompts (if any) and build the ``extra`` dict
        fed to the transformer decoder."""
        extra = {}
        if self.task != 'caption':
            # All tasks except captioning carry text prompts.
            all_text_prompts = []
            num_thing_class = 0
            for data_samples in batch_data_samples:
                if isinstance(data_samples.text, str):
                    text = data_samples.text.split('.')
                elif isinstance(data_samples.text, Sequence):
                    text = data_samples.text
                else:
                    raise TypeError(
                        'Type of data_samples.text must be a sequence or str')
                text = list(filter(lambda x: len(x) > 0, text))
                all_text_prompts.append(text)
                # Thing classes come first; stuff prompts are appended below.
                num_thing_class = len(text)
                # for panoptic
                if 'stuff_text' in data_samples:
                    if isinstance(data_samples.stuff_text, str):
                        text = data_samples.stuff_text.split('.')
                    elif isinstance(data_samples.stuff_text, Sequence):
                        text = data_samples.stuff_text
                    else:
                        raise TypeError('Type of data_samples.stuff_text '
                                        'must be a sequence or str')
                    text = list(filter(lambda x: len(x) > 0, text))
                    all_text_prompts[-1].extend(text)

            # TODO: support batch
            all_text_prompts = all_text_prompts[0]

            if all_text_prompts != self._all_text_prompts \
                    or self._force_not_use_cache:
                # Avoid redundant computation: only re-encode the prompts
                # when they differ from the cached ones.
                self._all_text_prompts = all_text_prompts
                if self.task in ['semseg', 'instance', 'panoptic']:
                    self.predictor.lang_encoder.get_mean_embeds(
                        all_text_prompts + ['background'])
                elif self.task == 'ref-seg':
                    token_info = self.predictor.lang_encoder.get_text_embeds(
                        all_text_prompts, norm=False)
                    token_emb = token_info['token_emb']
                    tokens = token_info['tokens']
                    query_emb = token_emb[tokens['attention_mask'].bool()]
                    extra['grounding_tokens'] = query_emb[:, None]
                    extra['class_emb'] = token_info['class_emb']
                elif self.task == 'retrieval':
                    token_info = self.predictor.lang_encoder.get_text_embeds(
                        all_text_prompts, norm=True)
                    extra['class_emb'] = token_info['class_emb']
                self._extra = extra
                return extra, all_text_prompts, num_thing_class
            else:
                return self._extra, all_text_prompts, num_thing_class
        else:
            if not hasattr(self, 'start_token'):
                self.start_token = self.predictor.lang_encoder. \
                    get_sot_token(device=device)
            extra['start_token'] = self.start_token
            return extra, None, None
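    # A minimal sketch of the caching behaviour above (``head`` and
    # ``samples`` are hypothetical names, for illustration only):
    #
    #   extra, prompts, _ = head.pre_process(samples, device)  # encodes text
    #   extra, prompts, _ = head.pre_process(samples, device)  # cache hit:
    #   # returns the stored ``self._extra`` without re-running the
    #   # language encoder, since the prompts match ``_all_text_prompts``.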
    def predict(self, features, batch_data_samples):
        # multi scale feature
        mask_features, multi_scale_features = self.pixel_decoder(features)

        # pre process
        extra, all_text_prompts, num_thing_class = self.pre_process(
            batch_data_samples, mask_features.device)

        # transformer decoder forward
        predictions = self.predictor(
            multi_scale_features, mask_features, extra=extra)

        # post process
        return self.post_process(predictions, batch_data_samples,
                                 all_text_prompts, num_thing_class)

    def post_process(self, predictions, batch_data_samples, all_text_prompts,
                     num_thing_class):
        batch_img_metas = [
            data_samples.metainfo for data_samples in batch_data_samples
        ]
        batch_input_shape = batch_data_samples[0].metainfo['batch_input_shape']

        if self.task == 'caption':
            for text, data_samples in zip(predictions['pred_caption'],
                                          batch_data_samples):
                data_samples.pred_caption = text

            # ref-caption: refine the intermediate masks produced by the
            # ref-seg stage into final instance predictions.
            if 'pred_instances' in batch_data_samples[0]:
                for img_metas, data_samples in zip(batch_img_metas,
                                                   batch_data_samples):
                    original_caption = data_samples.text.split('.')
                    text_prompts = list(
                        filter(lambda x: len(x) > 0, original_caption))

                    height = img_metas['ori_shape'][0]
                    width = img_metas['ori_shape'][1]
                    image_size = img_metas['grounding_img_shape'][:2]

                    mask_pred_result = \
                        data_samples.pred_instances.masks.float()
                    mask_cls_result = \
                        data_samples.pred_instances.scores.float()

                    mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)(
                        mask_pred_result, image_size, height, width)
                    pred_instances = retry_if_cuda_oom(
                        self._instance_inference)(mask_cls_result,
                                                  mask_pred_result,
                                                  text_prompts)
                    data_samples.pred_instances = pred_instances
        elif self.task in ['semseg', 'instance', 'panoptic']:
            mask_pred_results = predictions['pred_masks']
            mask_cls_results = predictions['pred_logits']
            # ``antialias`` is only supported by newer PyTorch versions,
            # hence the version guard.
            if is_lower_torch_version():
                mask_pred_results = F.interpolate(
                    mask_pred_results,
                    size=(batch_input_shape[-2], batch_input_shape[-1]),
                    mode='bicubic',
                    align_corners=False)
            else:
                mask_pred_results = F.interpolate(
                    mask_pred_results,
                    size=(batch_input_shape[-2], batch_input_shape[-1]),
                    mode='bicubic',
                    align_corners=False,
                    antialias=True)

            # for batch
            for mask_cls_result, \
                    mask_pred_result, \
                    img_metas, \
                    data_samples in zip(mask_cls_results, mask_pred_results,
                                        batch_img_metas, batch_data_samples):
                height = img_metas['ori_shape'][0]
                width = img_metas['ori_shape'][1]
                image_size = img_metas['img_shape'][:2]

                mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)(
                    mask_pred_result, image_size, height, width)
                mask_cls_result = mask_cls_result.to(mask_pred_result)

                if self.task == 'semseg':
                    pred_sem_seg = retry_if_cuda_oom(
                        self._semantic_inference)(mask_cls_result,
                                                  mask_pred_result,
                                                  all_text_prompts)
                    data_samples.pred_sem_seg = pred_sem_seg
                elif self.task == 'instance':
                    pred_instances = retry_if_cuda_oom(
                        self._instance_inference)(mask_cls_result,
                                                  mask_pred_result,
                                                  all_text_prompts)
                    data_samples.pred_instances = pred_instances
                elif self.task == 'panoptic':
                    pred_panoptic_seg = retry_if_cuda_oom(
                        self._panoptic_inference)(mask_cls_result,
                                                  mask_pred_result,
                                                  all_text_prompts,
                                                  num_thing_class)
                    data_samples.pred_panoptic_seg = pred_panoptic_seg
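        # In the ref-seg branch below, each image carries a stack of masks
        # predicted for the grounding queries. When ``return_inter_mask`` is
        # set (the ref-caption task), the masks are binarized at logit 0 and
        # handed back as-is, skipping the resize to the original image
        # resolution, so the caption stage can refine them later.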
        elif self.task == 'ref-seg':
            mask_pred_results = predictions['pred_masks']
            mask_cls_results = predictions['pred_logits']
            results_ = zip(mask_pred_results, mask_cls_results,
                           batch_img_metas, batch_data_samples)
            for mask_pred_result, mask_cls_result, \
                    img_metas, data_samples in results_:
                if is_lower_torch_version():
                    mask_pred_result = F.interpolate(
                        mask_pred_result[None],
                        size=(batch_input_shape[-2], batch_input_shape[-1]),
                        mode='bicubic',
                        align_corners=False)[0]
                else:
                    mask_pred_result = F.interpolate(
                        mask_pred_result[None],
                        size=(batch_input_shape[-2], batch_input_shape[-1]),
                        mode='bicubic',
                        align_corners=False,
                        antialias=True)[0]

                if self.return_inter_mask:
                    mask = mask_pred_result > 0
                    pred_instances = InstanceData()
                    pred_instances.masks = mask
                    pred_instances.scores = mask_cls_result
                    data_samples.pred_instances = pred_instances
                    continue

                height = img_metas['ori_shape'][0]
                width = img_metas['ori_shape'][1]
                image_size = img_metas['img_shape'][:2]

                mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)(
                    mask_pred_result, image_size, height, width)

                pred_instances = retry_if_cuda_oom(self._instance_inference)(
                    mask_cls_result, mask_pred_result, all_text_prompts)
                data_samples.pred_instances = pred_instances
        elif self.task == 'retrieval':
            batch_data_samples[0].pred_score = predictions['pred_logits']
        return batch_data_samples

    def _instance_inference(self, mask_cls, mask_pred, text_prompts):
        num_class = len(text_prompts)

        if self.task in ['ref-seg', 'caption']:
            # For grounding tasks, each query corresponds to one prompt.
            scores = F.softmax(mask_cls, dim=-1)
            scores_per_image = scores.max(dim=-1)[0]
            labels_per_image = torch.arange(num_class)
        else:
            # Drop the trailing background logit, then take the top-k
            # (query, class) pairs over the flattened score matrix.
            scores = F.softmax(mask_cls, dim=-1)[:, :-1]
            labels = torch.arange(
                num_class,
                device=scores.device).unsqueeze(0).repeat(scores.shape[0],
                                                          1).flatten(0, 1)
            scores_per_image, topk_indices = scores.flatten(0, 1).topk(
                self.test_cfg.get('max_per_img', 100), sorted=False)
            labels_per_image = labels[topk_indices]
            topk_indices = topk_indices // num_class
            mask_pred = mask_pred[topk_indices]

        result = InstanceData()
        mask_pred = mask_pred.sigmoid()
        result.masks = (mask_pred > self.test_cfg.mask_thr).float()

        # Calculate the average foreground probability inside each mask,
        # used to rescale the classification score.
        mask_scores_per_image = (mask_pred.flatten(1) *
                                 result.masks.flatten(1)).sum(1) / (
                                     result.masks.flatten(1).sum(1) + 1e-6)
        result.scores = scores_per_image * mask_scores_per_image
        result.labels = labels_per_image
        result.label_names = [
            text_prompts[label] for label in labels_per_image
        ]
        result.bboxes = result.scores.new_zeros(len(result.scores), 4)
        return result

    def _semantic_inference(self, mask_cls, mask_pred, text_prompts):
        mask_cls = F.softmax(mask_cls, dim=-1)[..., :-1]
        mask_pred = mask_pred.sigmoid()
        sem_seg = torch.einsum('qc,qhw->chw', mask_cls, mask_pred)

        if sem_seg.shape[0] == 1:
            # Single class: 0 is foreground, ignore_index is background.
            sem_seg = (sem_seg.squeeze(0) <= self.test_cfg.mask_thr).int()
            sem_seg[sem_seg == 1] = self.test_cfg.get('ignore_index', 255)
        else:
            # Multi class: 0 is foreground, ignore_index is background.
            if self.test_cfg.use_thr_for_mc:
                foreground_flag = sem_seg > self.test_cfg.mask_thr
                sem_seg = sem_seg.max(0)[1]
                sem_seg[foreground_flag.sum(0) == 0] = self.test_cfg.get(
                    'ignore_index', 255)
            else:
                sem_seg = sem_seg.max(0)[1]
        pred_sem_seg = PixelData(
            sem_seg=sem_seg[None],
            metainfo={
                'label_names': text_prompts,
                'ignore_index': self.test_cfg.get('ignore_index', 255)
            })
        return pred_sem_seg
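    # Panoptic ids below follow mmdet's convention of storing
    # ``class_id + instance_id * INSTANCE_OFFSET`` per pixel. A worked
    # example, assuming INSTANCE_OFFSET == 1000: the second "thing"
    # instance of class 3 is stored as 3 + 2 * 1000 = 2003, so the class
    # is recovered with ``2003 % 1000 == 3`` and the instance with
    # ``2003 // 1000 == 2``.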
    def _panoptic_inference(self, mask_cls, mask_pred, all_text_prompts,
                            num_thing_class):
        scores, labels = F.softmax(mask_cls, dim=-1).max(-1)
        mask_pred = mask_pred.sigmoid()

        keep = labels.ne(len(all_text_prompts)) & (
            scores > self.test_cfg.mask_thr)
        cur_scores = scores[keep]
        cur_classes = labels[keep]
        cur_masks = mask_pred[keep]
        cur_prob_masks = cur_scores.view(-1, 1, 1) * cur_masks

        h, w = cur_masks.shape[-2:]
        panoptic_seg = torch.full((h, w),
                                  self.test_cfg.get('ignore_index', 255),
                                  dtype=torch.int32,
                                  device=cur_masks.device)
        instance_id = 1

        if cur_masks.shape[0] > 0:
            cur_mask_ids = cur_prob_masks.argmax(0)
            for k in range(cur_classes.shape[0]):
                pred_class = cur_classes[k].item()
                isthing = int(pred_class) < num_thing_class
                mask_area = (cur_mask_ids == k).sum().item()
                original_area = (cur_masks[k] >= 0.5).sum().item()
                mask = (cur_mask_ids == k) & (cur_masks[k] >= 0.5)
                if mask_area > 0 and original_area > 0 \
                        and mask.sum().item() > 0:
                    if mask_area / original_area < self.test_cfg.overlap_thr:
                        continue
                    # merge stuff regions
                    if not isthing:
                        panoptic_seg[mask] = int(pred_class)
                    else:
                        panoptic_seg[mask] = int(
                            pred_class) + instance_id * INSTANCE_OFFSET
                        instance_id += 1

        panoptic_seg = PixelData(
            sem_seg=panoptic_seg[None],
            metainfo={
                'label_names': all_text_prompts,
                'ignore_index': self.test_cfg.get('ignore_index', 255)
            })
        return panoptic_seg
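
# A minimal usage sketch. The decoder type names and test_cfg keys are
# assumptions for illustration, not verified configs:
#
#   from mmdet.registry import MODELS
#
#   head = MODELS.build(
#       dict(
#           type='XDecoderUnifiedhead',
#           in_channels=512,
#           pixel_decoder=dict(type='XDecoderPixelDecoder'),  # hypothetical
#           transformer_decoder=dict(
#               type='XDecoderTransformerDecoder'),  # hypothetical
#           task='panoptic',
#           test_cfg=dict(mask_thr=0.5, overlap_thr=0.8, max_per_img=100)))
#   # batch_data_samples: DetDataSample list carrying `text`/`stuff_text`
#   results = head.predict(backbone_features, batch_data_samples)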