import numpy as np
import math
from typing import Tuple

import torch
import torch.nn.init as init
import torch.nn.functional as F
from mmcv.cnn import ConvModule, build_activation_layer

from mmengine.model import BaseModule, ModuleList, Sequential
from torch import Tensor, nn

from mmdet.registry import MODELS
from mmdet.utils import ConfigType, OptConfigType
from ..layers import CdnQueryGenerator, RtdetrDecoder
from .dino_utils import get_contrastive_denoising_training_group


def _bias_initial_with_prob(prob):
    """Return the bias value for which sigmoid(bias) equals ``prob``."""
    bias_init = float(-np.log((1 - prob) / prob))
    return bias_init


@torch.no_grad()
def _linear_init(module: nn.Module) -> None:
    """Uniform init mirroring PyTorch's default ``nn.Linear`` bound."""
    # nn.Linear stores weight as (out_features, in_features); the bound is
    # computed from fan_in, i.e. weight.shape[1].
    bound = 1 / math.sqrt(module.weight.shape[1])
    init.uniform_(module.weight, -bound, bound)
    if hasattr(module, 'bias') and module.bias is not None:
        init.uniform_(module.bias, -bound, bound)

class MLP(nn.Module):
    """Simple multi-layer perceptron used for the bbox and query-pos heads."""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers,
                 act=dict(type='ReLU')):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(
            nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
        self.act = nn.Identity() if act is None else build_activation_layer(act)

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = self.act(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x
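

# Illustrative note: MLP(256, 256, 4, num_layers=3) builds
# Linear(256, 256) -> Linear(256, 256) -> Linear(256, 4), applying the
# activation after every layer except the last; the bbox heads below use
# exactly this configuration.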
@MODELS.register_module()
class RtDetrTransformer(BaseModule):
    """RT-DETR transformer module.

    Projects multi-level backbone features, selects the top-k encoder tokens
    as initial object queries and refines them with a deformable-attention
    decoder, optionally adding CDN denoising queries during training.
    """

    def _build_input_proj_layer(self, backbone_feat_channels):
        self.input_proj = ModuleList()
        for in_channels in backbone_feat_channels:
            self.input_proj.append(
                ConvModule(
                    in_channels,
                    self.hidden_dim,
                    kernel_size=1,
                    bias=False,
                    norm_cfg=dict(type='BN', requires_grad=True),
                    act_cfg=None))
        in_channels = backbone_feat_channels[-1]
        for _ in range(self.num_levels - len(backbone_feat_channels)):
            self.input_proj.append(
                ConvModule(
                    in_channels,
                    self.hidden_dim,
                    3,
                    2,
                    1,
                    bias=False,
                    norm_cfg=dict(type='BN', requires_grad=True),
                    act_cfg=None
                )
            )
            in_channels = self.hidden_dim

    def _generate_anchors(self,
                          spatial_shapes: list = None,
                          grid_size: float = 0.05,
                          dtype=torch.float32,
                          device='cpu'):
        """Generate multi-level anchors in inverse-sigmoid (logit) space."""
        if spatial_shapes is None:
            spatial_shapes=[[int(self.eval_spatial_size[0] / s), int(self.eval_spatial_size[1] / s)] for s in self.feat_strides]
        anchors=[]
        for lvl, (h, w) in enumerate(spatial_shapes):
            grid_y, grid_x = torch.meshgrid(
                torch.arange(end=h, dtype=dtype),
                torch.arange(end=w, dtype=dtype),
                indexing='ij')
            grid_xy = torch.stack([grid_x, grid_y], -1)
            valid_WH = torch.tensor([w, h]).to(dtype)
            grid_xy = (grid_xy.unsqueeze(0) + 0.5) / valid_WH
            wh = torch.ones_like(grid_xy) * grid_size * (2.0 ** lvl)
            anchors.append(torch.concat([grid_xy, wh], -1).reshape(-1, h * w, 4))
        anchors = torch.concat(anchors, 1).to(device)
        valid_mask = ((anchors > self.eps) * (anchors < 1 - self.eps)).all(-1, keepdim=True)
        anchors = torch.log(anchors / (1 - anchors))
        anchors = torch.where(valid_mask, anchors, torch.inf)
        return anchors, valid_mask
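
    # For reference: with eval_spatial_size=(640, 640) and strides [8, 16, 32],
    # spatial_shapes is [[80, 80], [40, 40], [20, 20]], so the concatenated
    # anchors have shape (1, 80*80 + 40*40 + 20*20, 4) = (1, 8400, 4), stored
    # in inverse-sigmoid (logit) space with out-of-range entries set to inf.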

    def _reset_parameter(self):
        bias_cls = _bias_initial_with_prob(0.01)
        _linear_init(self.encoder_score_head)
        init.constant_(self.encoder_score_head.bias, bias_cls)
        init.constant_(self.encoder_bbox_head.layers[-1].weight, 0.)
        init.constant_(self.encoder_bbox_head.layers[-1].bias, 0.)
        for cls_, reg_ in zip(self.decoder_score_head, self.decoder_bbox_head):
            _linear_init(cls_)
            init.constant_(cls_.bias, bias_cls)
            init.constant_(reg_.layers[-1].weight, 0.)
            init.constant_(reg_.layers[-1].bias, 0.)
        _linear_init(self.encoder_output[0])
        init.xavier_uniform_(self.encoder_output[0].weight)
        if self.learnt_init_query:
            init.xavier_uniform_(self.tgt_embed.weight)
        init.xavier_uniform_(self.query_pos_head.layers[0].weight)
        init.xavier_uniform_(self.query_pos_head.layers[1].weight)
        # for l in self.input_proj:
        #     init.xavier_uniform_(l.weight)

    def _get_encoder_input(self,
                           feats: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        proj_feats = [self.input_proj[i](feat) for i, feat in enumerate(feats)]
        if self.num_levels > len(proj_feats):
            len_srcs = len(proj_feats)
            for i in range(len_srcs, self.num_levels):
                if i == len_srcs:
                    proj_feats.append(self.input_proj[i](feats[-1]))
                else:
                    proj_feats.append(self.input_proj[i](proj_feats[-1]))
        # get flattened input for the encoder
        feat_flatten = []
        spatial_shapes = []
        for feat in proj_feats:
            spatial_shape = torch._shape_as_tensor(feat)[2:].to(feat.device)
            # [b, c, h, w] -> [b, h*w, c]
            feat_flatten.append(feat.flatten(2).permute(0, 2, 1))
            # [num_levels, 2] each level
            spatial_shapes.append(spatial_shape)
        spatial_shapes = torch.cat(spatial_shapes).view(-1, 2)
        level_start_index = torch.cat((
            spatial_shapes.new_zeros((1, )),  # (num_level)
            spatial_shapes.prod(1).cumsum(0)[:-1]))
        # [b, l, c]
        feat_flatten = torch.concat(feat_flatten, 1)
        return (feat_flatten, spatial_shapes, level_start_index)
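
    # For reference: a 640x640 input with strides [8, 16, 32] yields
    # feat_flatten of shape (b, 6400 + 1600 + 400, C) = (b, 8400, C),
    # spatial_shapes = [[80, 80], [40, 40], [20, 20]] and
    # level_start_index = [0, 6400, 8000].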

    def _get_decoder_input(self,
                           memory: Tensor,
                           spatial_shapes,
                           denoising_class=None,
                           denoising_bbox_unact=None):
        bs, _, _ = memory.shape
        # prepare input for decoder
        if self.training or self.eval_spatial_size is None:
            anchors, valid_mask = self._generate_anchors(
                spatial_shapes, device=memory.device)
        else:
            anchors = self.anchors.to(memory.device)
            valid_mask = self.valid_mask.to(memory.device)
        memory = valid_mask.to(memory.dtype) * memory
        output_memory = self.encoder_output(memory)
        enc_outputs_class = self.encoder_score_head(output_memory)
        enc_outputs_coord_unact = self.encoder_bbox_head(output_memory) + anchors
        topk_ind = torch.topk(enc_outputs_class.max(-1).values, self.num_queries, dim=1)[1]
        # extract region proposal boxes
        reference_points_unact = enc_outputs_coord_unact.gather(
            dim=1,
            index=topk_ind.unsqueeze(-1).repeat(
                1, 1, enc_outputs_coord_unact.shape[-1]))
        enc_topk_bboxes = F.sigmoid(reference_points_unact)
        if denoising_bbox_unact is not None:
            reference_points_unact = torch.concat(
                [denoising_bbox_unact, reference_points_unact], 1)
        enc_topk_logits = enc_outputs_class.gather(
            dim=1,
            index=topk_ind.unsqueeze(-1).repeat(
                1, 1, enc_outputs_class.shape[-1]))
        # extract region features
        if self.learnt_init_query:
            target = self.tgt_embed.weight.unsqueeze(0).tile([bs, 1, 1])
        else:
            target = output_memory.gather(
                dim=1,
                index=topk_ind.unsqueeze(-1).repeat(
                    1, 1, output_memory.shape[-1]))
            target = target.detach()
        if denoising_class is not None:
            target = torch.concat([denoising_class, target], 1)
        return target, reference_points_unact.detach(), enc_topk_bboxes, enc_topk_logits
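
    # Note: the num_queries encoder tokens with the highest class score are
    # selected; their unactivated boxes (anchor + predicted offset) become the
    # initial reference points, the content queries are either the gathered
    # token features or a learned embedding, and denoising queries, when
    # provided, are prepended to both.
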
    def __init__(self,
                 num_classes: int,
                 hidden_dim: int,
                 num_queries: int,
                 position_type: str = 'sine',
                 feat_channels: list = [256, 256, 256],
                 feat_strides: list = [8, 16, 32],
                 num_levels: int = 3,
                 num_crossattn_points: int = 4,
                 number_head: int = 8,
                 number_decoder_layer: int = 6,
                 dim_feedforward_ratio: int = 4,
                 dropout: float = 0.0,
                 act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
                 num_denoising: int = 100,
                 label_noise_ratio: float = 0.5,
                 box_noise_scale: float = 1.0,
                 learnt_init_query: bool = True,
                 eval_size: list = None,
                 eval_spatial_size: list = None,
                 eval_idx: int = -1,
                 eps: float = 1e-2):
        super().__init__()
        assert position_type in ['sine', 'learned'], \
            f'ValueError: position_embed_type not supported {position_type}!'
        assert len(feat_channels) <= num_levels
        assert len(feat_strides) == len(feat_channels)
        feat_strides = list(feat_strides)  # copy so the default list is not mutated
        for _ in range(num_levels - len(feat_strides)):
            feat_strides.append(feat_strides[-1] * 2)
        self.hidden_dim = hidden_dim
        self.number_head = number_head
        self.feat_strides = feat_strides
        self.num_levels = num_levels
        self.num_classes = num_classes
        self.num_queries = num_queries
        self.eps = eps
        self.number_decoder_layer = number_decoder_layer
        self.eval_size = eval_size
        self.num_denoising = num_denoising
        self.label_noise_ratio = label_noise_ratio
        self.box_noise_scale = box_noise_scale
        self.eval_idx = eval_idx
        self.eval_spatial_size = eval_spatial_size
        # backbone feature projection
        self._build_input_proj_layer(feat_channels)

        # Transformer decoder layer configs; the keys follow the mmcv/mmdet
        # MultiheadAttention, MultiScaleDeformableAttention and FFN modules
        # that RtdetrDecoder builds from this cfg.
        self_attn_cfg = dict(
            embed_dims=hidden_dim,
            num_heads=number_head,
            attn_drop=dropout,
            proj_drop=dropout,
            batch_first=True)

        cross_attn_cfg = dict(
            embed_dims=hidden_dim,
            num_heads=number_head,
            num_levels=num_levels,
            num_points=num_crossattn_points,
            dropout=dropout,
            batch_first=True)
        ffn_cfg = dict(
            embed_dims=hidden_dim,
            feedforward_channels=hidden_dim * dim_feedforward_ratio,
            num_fcs=2,
            ffn_drop=0,
            act_cfg=act_cfg)
        decode_layer_cfg = dict(
            self_attn_cfg=self_attn_cfg,
            cross_attn_cfg=cross_attn_cfg,
            ffn_cfg=ffn_cfg)
        self.decoder = RtdetrDecoder(
            num_layers=number_decoder_layer, layer_cfg=decode_layer_cfg)
        # denoising query generator (CDN-style, as used in DINO)
        if num_denoising > 0:
            self.dino = CdnQueryGenerator(
                num_classes=num_classes,
                embed_dims=hidden_dim,
                num_matching_queries=num_queries,
                label_noise_scale=label_noise_ratio,
                box_noise_scale=box_noise_scale,
                group_cfg=dict(
                    dynamic=True,
                    num_groups=None,
                    num_dn_queries=num_denoising))
        
        # decoder embedding
        self.learnt_init_query = learnt_init_query
        if learnt_init_query:
            self.tgt_embed = nn.Embedding(num_queries, hidden_dim)

        self.query_pos_head = MLP(
            4, hidden_dim=hidden_dim, output_dim=hidden_dim, num_layers=2)

        # encoder head in transformer
        self.encoder_output = Sequential(
            nn.Linear(hidden_dim, hidden_dim), nn.LayerNorm(hidden_dim))
        self.encoder_score_head = nn.Linear(hidden_dim, num_classes)
        self.encoder_bbox_head = MLP(hidden_dim, hidden_dim, 4, num_layers=3)

        # decoder head in transformer
        self.decoder_score_head = ModuleList(
            nn.Linear(hidden_dim, num_classes)
            for _ in range(number_decoder_layer))
        self.decoder_bbox_head = ModuleList(
            MLP(hidden_dim, hidden_dim, 4, num_layers=3)
            for _ in range(number_decoder_layer))
        if self.eval_spatial_size:
            # cache anchors/masks for the fixed evaluation input size
            self.anchors, self.valid_mask = self._generate_anchors()
        # reset parameters of the encoder and decoder heads
        self._reset_parameter()

    def forward(self, feats, pad_mask=None, gt_meta=None):
        memory, spatial_shapes, level_start_index = \
            self._get_encoder_input(feats)
        if self.training and self.num_denoising:
            denoising_class, denoising_bbox_unact, attn_mask, dn_meta = \
                self.dino(batch_data_samples=gt_meta)
        else:
            denoising_class, denoising_bbox_unact, attn_mask, dn_meta = \
                None, None, None, None
        target, init_ref_points_unact, enc_topk_bboxes, enc_topk_logits = \
            self._get_decoder_input(memory, spatial_shapes, denoising_class,
                                    denoising_bbox_unact)
        # run the decoder: per-layer class scores and box predictions
        # for the queries and their reference points
        query, reference = self.decoder(
            target=target,
            memory=memory,
            memory_spatial_shapes=spatial_shapes,
            memory_level_start_index=level_start_index,
            ref_points_unact=init_ref_points_unact,
            query_pos_head=self.query_pos_head,
            bbox_head=self.decoder_bbox_head,
            score_head=self.decoder_score_head,
            attn_mask=attn_mask)
        return (query, reference, enc_topk_bboxes, enc_topk_logits,
                dn_meta)
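

if __name__ == '__main__':
    # Minimal smoke-test sketch (illustrative only). It assumes the custom
    # RtdetrDecoder / CdnQueryGenerator layers are importable and that the
    # neck provides three 256-channel maps at strides 8/16/32 for a 640x640
    # image; the argument values below are example assumptions, not fixed
    # requirements of the module.
    model = RtDetrTransformer(
        num_classes=80,
        hidden_dim=256,
        num_queries=300,
        feat_channels=[256, 256, 256],
        feat_strides=[8, 16, 32],
        num_levels=3)
    model.eval()  # eval mode skips the denoising branch (which needs GT samples)
    feats = [
        torch.randn(1, 256, 80, 80),
        torch.randn(1, 256, 40, 40),
        torch.randn(1, 256, 20, 20),
    ]
    with torch.no_grad():
        query, reference, enc_bboxes, enc_logits, dn_meta = model(feats)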