# Copyright (c) Tianheng Cheng and its affiliates. All Rights Reserved
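"""IAM-based decoders for SparseInst-style instance segmentation.

Defines the instance activation map (IAM) decoder variants registered with
mmdet as ``BaseIAMDecoder``, ``GroupIAMDecoder`` and ``GroupIAMSoftDecoder``.
"""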

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.model.weight_init import caffe2_xavier_init, kaiming_init
from torch.nn import init

from mmdet.registry import MODELS


def _make_stack_3x3_convs(num_convs,
                          in_channels,
                          out_channels,
                          act_cfg=dict(type='ReLU', inplace=True)):
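    """Stack ``num_convs`` 3x3 convs, each followed by an activation layer
    built from ``act_cfg``."""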
    convs = []
    for _ in range(num_convs):
        convs.append(nn.Conv2d(in_channels, out_channels, 3, padding=1))
        convs.append(MODELS.build(act_cfg))
        in_channels = out_channels
    return nn.Sequential(*convs)


class InstanceBranch(nn.Module):
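    """Instance branch of the IAM decoder.

    Predicts one instance activation map (IAM) per instance with a single
    3x3 conv, aggregates pixel features weighted by the spatially
    normalized IAMs, and predicts classification logits, a mask kernel and
    an objectness score for every instance.
    """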

    def __init__(self,
                 in_channels,
                 dim=256,
                 num_convs=4,
                 num_masks=100,
                 num_classes=80,
                 kernel_dim=128,
                 act_cfg=dict(type='ReLU', inplace=True)):
        super().__init__()
        self.num_masks = num_masks
        self.num_classes = num_classes

        self.inst_convs = _make_stack_3x3_convs(num_convs, in_channels, dim,
                                                act_cfg)
        # iam prediction, a simple conv
        self.iam_conv = nn.Conv2d(dim, num_masks, 3, padding=1)

        # outputs
        self.cls_score = nn.Linear(dim, self.num_classes)
        self.mask_kernel = nn.Linear(dim, kernel_dim)
        self.objectness = nn.Linear(dim, 1)

        self.prior_prob = 0.01
        self._init_weights()

    def _init_weights(self):
        for m in self.inst_convs.modules():
            if isinstance(m, nn.Conv2d):
                kaiming_init(m)
        bias_value = -math.log((1 - self.prior_prob) / self.prior_prob)
        for module in [self.iam_conv, self.cls_score]:
            init.constant_(module.bias, bias_value)
        init.normal_(self.iam_conv.weight, std=0.01)
        init.normal_(self.cls_score.weight, std=0.01)

        init.normal_(self.mask_kernel.weight, std=0.01)
        init.constant_(self.mask_kernel.bias, 0.0)

    def forward(self, features):
        # instance features (stacked 3x3 convs)
        features = self.inst_convs(features)
        # predict instance activation maps
        iam = self.iam_conv(features)
        iam_prob = iam.sigmoid()

        B, N = iam_prob.shape[:2]
        C = features.size(1)
        # BxNxHxW -> BxNx(HW)
        iam_prob = iam_prob.view(B, N, -1)
        normalizer = iam_prob.sum(-1).clamp(min=1e-6)
        iam_prob = iam_prob / normalizer[:, :, None]
        # aggregate features with the IAM weights:
        # BxNx(HW) @ Bx(HW)xC -> BxNxC
        inst_features = torch.bmm(iam_prob,
                                  features.view(B, C, -1).permute(0, 2, 1))
        # predict classification & segmentation kernel & objectness
        pred_logits = self.cls_score(inst_features)
        pred_kernel = self.mask_kernel(inst_features)
        pred_scores = self.objectness(inst_features)
        return pred_logits, pred_kernel, pred_scores, iam


class MaskBranch(nn.Module):
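    """Mask branch of the IAM decoder.

    A stack of 3x3 convs followed by a 1x1 projection producing
    ``kernel_dim``-channel mask features shared by all instances.
    """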

    def __init__(self,
                 in_channels,
                 dim=256,
                 num_convs=4,
                 kernel_dim=128,
                 act_cfg=dict(type='ReLU', inplace=True)):
        super().__init__()
        self.mask_convs = _make_stack_3x3_convs(num_convs, in_channels, dim,
                                                act_cfg)
        self.projection = nn.Conv2d(dim, kernel_dim, kernel_size=1)
        self._init_weights()

    def _init_weights(self):
        for m in self.mask_convs.modules():
            if isinstance(m, nn.Conv2d):
                kaiming_init(m)
        kaiming_init(self.projection)

    def forward(self, features):
        # mask features (stacked 3x3 convs)
        features = self.mask_convs(features)
        return self.projection(features)


@MODELS.register_module()
class BaseIAMDecoder(nn.Module):
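    """IAM-based decoder with a plain instance branch.

    Normalized coordinate maps are concatenated to the input features
    before both branches; instance masks are obtained by multiplying the
    predicted kernels with the shared mask features and upsampling the
    result by ``scale_factor``.
    """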

    def __init__(self,
                 in_channels,
                 num_classes,
                 ins_dim=256,
                 ins_conv=4,
                 mask_dim=256,
                 mask_conv=4,
                 kernel_dim=128,
                 scale_factor=2.0,
                 output_iam=False,
                 num_masks=100,
                 act_cfg=dict(type='ReLU', inplace=True)):
        super().__init__()
        # add 2 for coordinates
        in_channels = in_channels + 2

        self.scale_factor = scale_factor
        self.output_iam = output_iam

        self.inst_branch = InstanceBranch(
            in_channels,
            dim=ins_dim,
            num_convs=ins_conv,
            num_masks=num_masks,
            num_classes=num_classes,
            kernel_dim=kernel_dim,
            act_cfg=act_cfg)
        self.mask_branch = MaskBranch(
            in_channels,
            dim=mask_dim,
            num_convs=mask_conv,
            kernel_dim=kernel_dim,
            act_cfg=act_cfg)

    @torch.no_grad()
    def compute_coordinates_linspace(self, x):
        # kept for reference: linspace is not supported by ONNX export,
        # so compute_coordinates() below uses arange instead
        h, w = x.size(2), x.size(3)
        y_loc = torch.linspace(-1, 1, h, device=x.device)
        x_loc = torch.linspace(-1, 1, w, device=x.device)
        y_loc, x_loc = torch.meshgrid(y_loc, x_loc, indexing='ij')
        y_loc = y_loc.expand([x.shape[0], 1, -1, -1])
        x_loc = x_loc.expand([x.shape[0], 1, -1, -1])
        locations = torch.cat([x_loc, y_loc], 1)
        return locations.to(x)

    @torch.no_grad()
    def compute_coordinates(self, x):
        h, w = x.size(2), x.size(3)
        y_loc = -1.0 + 2.0 * torch.arange(h, device=x.device) / (h - 1)
        x_loc = -1.0 + 2.0 * torch.arange(w, device=x.device) / (w - 1)
        y_loc, x_loc = torch.meshgrid(y_loc, x_loc, indexing='ij')
        y_loc = y_loc.expand([x.shape[0], 1, -1, -1])
        x_loc = x_loc.expand([x.shape[0], 1, -1, -1])
        locations = torch.cat([x_loc, y_loc], 1)
        return locations.to(x)

    def forward(self, features):
        coord_features = self.compute_coordinates(features)
        features = torch.cat([coord_features, features], dim=1)
        pred_logits, pred_kernel, pred_scores, iam = self.inst_branch(features)
        mask_features = self.mask_branch(features)

        N = pred_kernel.shape[1]
        # mask_features: BxCxHxW
        B, C, H, W = mask_features.shape
        pred_masks = torch.bmm(pred_kernel,
                               mask_features.view(B, C,
                                                  H * W)).view(B, N, H, W)

        pred_masks = F.interpolate(
            pred_masks,
            scale_factor=self.scale_factor,
            mode='bilinear',
            align_corners=False)

        output = {
            'pred_logits': pred_logits,
            'pred_masks': pred_masks,
            'pred_scores': pred_scores,
        }

        if self.output_iam:
            iam = F.interpolate(
                iam,
                scale_factor=self.scale_factor,
                mode='bilinear',
                align_corners=False)
            output['pred_iam'] = iam

        return output


class GroupInstanceBranch(nn.Module):
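    """Instance branch that predicts ``num_groups`` IAMs per instance.

    A group conv predicts ``num_groups`` activation maps for every
    instance; the aggregated features of the groups are concatenated per
    instance and fused by a fully-connected layer before the final
    predictions.
    """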

    def __init__(self,
                 in_channels,
                 num_groups=4,
                 dim=256,
                 num_convs=4,
                 num_masks=100,
                 num_classes=80,
                 kernel_dim=128,
                 act_cfg=dict(type='ReLU', inplace=True)):
        super().__init__()
        self.num_groups = num_groups
        self.num_classes = num_classes

        self.inst_convs = _make_stack_3x3_convs(
            num_convs, in_channels, dim, act_cfg=act_cfg)
        # iam prediction, a group conv
        expand_dim = dim * self.num_groups
        self.iam_conv = nn.Conv2d(
            dim,
            num_masks * self.num_groups,
            3,
            padding=1,
            groups=self.num_groups)
        # outputs
        self.fc = nn.Linear(expand_dim, expand_dim)

        self.cls_score = nn.Linear(expand_dim, self.num_classes)
        self.mask_kernel = nn.Linear(expand_dim, kernel_dim)
        self.objectness = nn.Linear(expand_dim, 1)

        self.prior_prob = 0.01
        self._init_weights()

    def _init_weights(self):
        for m in self.inst_convs.modules():
            if isinstance(m, nn.Conv2d):
                kaiming_init(m)
        bias_value = -math.log((1 - self.prior_prob) / self.prior_prob)
        for module in [self.iam_conv, self.cls_score]:
            init.constant_(module.bias, bias_value)
        init.normal_(self.iam_conv.weight, std=0.01)
        init.normal_(self.cls_score.weight, std=0.01)

        init.normal_(self.mask_kernel.weight, std=0.01)
        init.constant_(self.mask_kernel.bias, 0.0)
        caffe2_xavier_init(self.fc)

    def forward(self, features):
        # instance features (stacked 3x3 convs)
        features = self.inst_convs(features)
        # predict instance activation maps
        iam = self.iam_conv(features)
        iam_prob = iam.sigmoid()

        B, N = iam_prob.shape[:2]
        C = features.size(1)
        # BxNxHxW -> BxNx(HW)
        iam_prob = iam_prob.view(B, N, -1)
        normalizer = iam_prob.sum(-1).clamp(min=1e-6)
        iam_prob = iam_prob / normalizer[:, :, None]

        # aggregate features with the IAM weights:
        # BxNx(HW) @ Bx(HW)xC -> BxNxC
        inst_features = torch.bmm(iam_prob,
                                  features.view(B, C, -1).permute(0, 2, 1))

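        # each instance owns num_groups activation maps; concatenate its
        # group features along the channel dim:
        # Bx(G*M)xC -> BxMx(G*C), with G = num_groups, M = num_masks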
        inst_features = inst_features.reshape(B, self.num_groups,
                                              N // self.num_groups,
                                              -1).transpose(1, 2).reshape(
                                                  B, N // self.num_groups, -1)

        inst_features = F.relu_(self.fc(inst_features))
        # predict classification & segmentation kernel & objectness
        pred_logits = self.cls_score(inst_features)
        pred_kernel = self.mask_kernel(inst_features)
        pred_scores = self.objectness(inst_features)
        return pred_logits, pred_kernel, pred_scores, iam


@MODELS.register_module()
class GroupIAMDecoder(BaseIAMDecoder):
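    """IAM-based decoder that replaces the plain instance branch of
    :class:`BaseIAMDecoder` with a :class:`GroupInstanceBranch`."""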

    def __init__(self,
                 in_channels,
                 num_classes,
                 num_groups=4,
                 ins_dim=256,
                 ins_conv=4,
                 mask_dim=256,
                 mask_conv=4,
                 kernel_dim=128,
                 scale_factor=2.0,
                 output_iam=False,
                 num_masks=100,
                 act_cfg=dict(type='ReLU', inplace=True)):
        super().__init__(
            in_channels=in_channels,
            num_classes=num_classes,
            ins_dim=ins_dim,
            ins_conv=ins_conv,
            mask_dim=mask_dim,
            mask_conv=mask_conv,
            kernel_dim=kernel_dim,
            scale_factor=scale_factor,
            output_iam=output_iam,
            num_masks=num_masks,
            act_cfg=act_cfg)
        self.inst_branch = GroupInstanceBranch(
            in_channels + 2,
            num_groups=num_groups,
            dim=ins_dim,
            num_convs=ins_conv,
            num_masks=num_masks,
            num_classes=num_classes,
            kernel_dim=kernel_dim,
            act_cfg=act_cfg)


class GroupInstanceSoftBranch(GroupInstanceBranch):
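    """Grouped instance branch with softmax-normalized IAMs.

    Replaces the sigmoid activation and sum normalization of
    :class:`GroupInstanceBranch` with a spatial softmax shifted by a
    learnable scalar bias.
    """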

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.softmax_bias = nn.Parameter(torch.ones([1]))

    def forward(self, features):
        # instance features (stacked 3x3 convs)
        features = self.inst_convs(features)
        # predict instance activation maps
        iam = self.iam_conv(features)

        B, N = iam.shape[:2]
        C = features.size(1)
        # BxNxHxW -> BxNx(HW)
        iam_prob = F.softmax(iam.view(B, N, -1) + self.softmax_bias, dim=-1)
        # aggregate features with the IAM weights:
        # BxNx(HW) @ Bx(HW)xC -> BxNxC
        inst_features = torch.bmm(iam_prob,
                                  features.view(B, C, -1).permute(0, 2, 1))

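        # each instance owns num_groups activation maps; concatenate its
        # group features along the channel dim:
        # Bx(G*M)xC -> BxMx(G*C), with G = num_groups, M = num_masks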
        inst_features = inst_features.reshape(B, self.num_groups,
                                              N // self.num_groups,
                                              -1).transpose(1, 2).reshape(
                                                  B, N // self.num_groups, -1)

        inst_features = F.relu_(self.fc(inst_features))
        # predict classification & segmentation kernel & objectness
        pred_logits = self.cls_score(inst_features)
        pred_kernel = self.mask_kernel(inst_features)
        pred_scores = self.objectness(inst_features)
        return pred_logits, pred_kernel, pred_scores, iam


@MODELS.register_module()
class GroupIAMSoftDecoder(BaseIAMDecoder):
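    """IAM-based decoder that replaces the plain instance branch of
    :class:`BaseIAMDecoder` with a :class:`GroupInstanceSoftBranch`."""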

    def __init__(self,
                 in_channels,
                 num_classes,
                 num_groups=4,
                 ins_dim=256,
                 ins_conv=4,
                 mask_dim=256,
                 mask_conv=4,
                 kernel_dim=128,
                 scale_factor=2.0,
                 output_iam=False,
                 num_masks=100,
                 act_cfg=dict(type='ReLU', inplace=True)):
        super().__init__(
            in_channels=in_channels,
            num_classes=num_classes,
            ins_dim=ins_dim,
            ins_conv=ins_conv,
            mask_dim=mask_dim,
            mask_conv=mask_conv,
            kernel_dim=kernel_dim,
            scale_factor=scale_factor,
            output_iam=output_iam,
            num_masks=num_masks,
            act_cfg=act_cfg)
        self.inst_branch = GroupInstanceSoftBranch(
            in_channels + 2,
            num_groups=num_groups,
            dim=ins_dim,
            num_convs=ins_conv,
            num_masks=num_masks,
            num_classes=num_classes,
            kernel_dim=kernel_dim,
            act_cfg=act_cfg)
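

if __name__ == '__main__':
    # Minimal smoke test: a sketch of how the decoders are used, assuming a
    # 256-channel encoder output at 64x64 resolution (both numbers are
    # illustrative, not prescribed by this module). The decoders can also be
    # built from config via MODELS.build thanks to the register_module
    # decorators above.
    decoder = GroupIAMSoftDecoder(in_channels=256, num_classes=80)
    dummy_features = torch.randn(2, 256, 64, 64)
    outputs = decoder(dummy_features)
    # expected shapes: pred_logits (2, 100, 80),
    # pred_masks (2, 100, 128, 128), pred_scores (2, 100, 1)
    for name, tensor in outputs.items():
        print(name, tuple(tensor.shape))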