import argparse
import warnings

import numpy as np
import torch
from tensorflow.python.training import py_checkpoint_reader

# Print tensors with 20 digits after the decimal point.
# NOTE(review): nothing in the visible code prints a tensor, so this appears to
# be a debugging aid; it does not affect the converted weights.
torch.set_printoptions(precision=20)


def tf2pth(v):
    """Permute a TensorFlow checkpoint array into PyTorch axis order.

    Rank-4 arrays get axes ``(3, 2, 0, 1)`` (presumably TF's HWIO conv-kernel
    layout becoming PyTorch's OIHW) and rank-2 arrays are transposed. Both
    results are C-contiguous copies. Arrays of any other rank are returned
    as-is, i.e. the very same object.
    """
    permutation = {4: (3, 2, 0, 1), 2: (1, 0)}.get(v.ndim)
    if permutation is None:
        return v
    return np.ascontiguousarray(v.transpose(permutation))


# TF EfficientNet block index ("blocks_<i>") -> "<stage>.<block>" suffix of the
# PyTorch backbone layer ("backbone.layers.<stage>.<block>").
# NOTE(review): only indices 0-15 are listed; a checkpoint with more blocks
# raises KeyError until this table is extended for that backbone.
_BLOCK_IDX_TO_LAYER = {
    0: '1.0',
    1: '2.0',
    2: '2.1',
    3: '3.0',
    4: '3.1',
    5: '4.0',
    6: '4.1',
    7: '4.2',
    8: '4.3',
    9: '4.4',
    10: '4.5',
    11: '5.0',
    12: '5.1',
    13: '5.2',
    14: '5.3',
    15: '5.4',
}

# TF batch-norm variable -> PyTorch BatchNorm attribute.
_BN_PARAMS = {
    'beta': 'bias',
    'gamma': 'weight',
    'moving_mean': 'running_mean',
    'moving_variance': 'running_var',
}


def _merged(*mappings):
    """Return a new dict with the entries of ``mappings`` (later ones win)."""
    result = {}
    for mapping in mappings:
        result.update(mapping)
    return result


def _bn_mapping(tf_scope, pt_scope):
    """Map TF batch-norm variables under ``tf_scope`` onto ``pt_scope``."""
    return {
        tf_scope + '/' + tf_param: pt_scope + '.' + pt_param
        for tf_param, pt_param in _BN_PARAMS.items()
    }


def _sep_conv_mapping(tf_prefix, pt_prefix):
    """Map TF separable-conv variables (depthwise/pointwise kernel, bias)."""
    return {
        tf_prefix + 'depthwise_kernel': pt_prefix + 'depthwise_conv.weight',
        tf_prefix + 'pointwise_kernel': pt_prefix + 'pointwise_conv.weight',
        tf_prefix + 'bias': pt_prefix + 'pointwise_conv.bias',
    }


def _resample_mapping(tf_scope, pt_scope):
    """Map a TF resample op (conv2d kernel/bias + bn) onto ``pt_scope``."""
    return _merged(
        {
            tf_scope + '/conv2d/kernel': pt_scope + '.down_conv.weight',
            tf_scope + '/conv2d/bias': pt_scope + '.down_conv.bias',
        },
        _bn_mapping(tf_scope + '/bn', pt_scope + '.bn'),
    )


def _lookup(mapping, key, tf_name):
    """Return ``mapping[key]``; raise a KeyError naming ``tf_name`` if absent."""
    try:
        return mapping[key]
    except KeyError:
        raise KeyError(
            'no PyTorch mapping for TF variable %r' % tf_name) from None


_DEPTHWISE_MAPPING = {
    'depthwise_conv2d/depthwise_kernel': 'depthwise_conv.conv.weight',
}
_SE_MAPPING = {
    'se/conv2d/kernel': 'se.conv1.conv.weight',
    'se/conv2d/bias': 'se.conv1.conv.bias',
    'se/conv2d_1/kernel': 'se.conv2.conv.weight',
    'se/conv2d_1/bias': 'se.conv2.conv.bias',
}
_STEM_MAPPING = _merged(
    {'conv2d/kernel': 'conv.weight'},
    _bn_mapping('tpu_batch_normalization', 'bn'),
)
# blocks_0 is mapped without an expand_conv: its plain conv2d is the projection
# (linear_conv) and its first batch norm belongs to the depthwise conv.
_FIRST_BLOCK_MAPPING = _merged(
    {'conv2d/kernel': 'linear_conv.conv.weight'},
    _bn_mapping('tpu_batch_normalization', 'depthwise_conv.bn'),
    _bn_mapping('tpu_batch_normalization_1', 'linear_conv.bn'),
    _DEPTHWISE_MAPPING,
    _SE_MAPPING,
)
_BLOCK_MAPPING = _merged(
    {
        'conv2d/kernel': 'expand_conv.conv.weight',
        'conv2d_1/kernel': 'linear_conv.conv.weight',
    },
    _bn_mapping('tpu_batch_normalization', 'expand_conv.bn'),
    _bn_mapping('tpu_batch_normalization_1', 'depthwise_conv.bn'),
    _bn_mapping('tpu_batch_normalization_2', 'linear_conv.bn'),
    _DEPTHWISE_MAPPING,
    _SE_MAPPING,
)
_RESAMPLE_P6_MAPPING = _resample_mapping(
    'resample_p6', 'neck.bifpn.0.p5_to_p6.0')

# BiFPN node id ("fnode<i>", whose TF conv scope is op_after_combine<i + 5>)
# -> (PyTorch conv module, fusion-weight name, number of WSM inputs,
#     (TF resample scope, PyTorch module) mapped in cell 0 only, or None).
_FPN_NODES = {
    0: ('conv6_up', 'p6_w1', 2, None),
    1: ('conv5_up', 'p5_w1', 2, ('resample_0_2_6', 'p5_down_channel')),
    2: ('conv4_up', 'p4_w1', 2, ('resample_0_1_7', 'p4_down_channel')),
    3: ('conv3_up', 'p3_w1', 2, ('resample_0_0_8', 'p3_down_channel')),
    4: ('conv4_down', 'p4_w2', 3,
        ('resample_0_1_9', 'p4_level_connection')),
    5: ('conv5_down', 'p5_w2', 3,
        ('resample_0_2_10', 'p5_level_connection')),
    6: ('conv6_down', 'p6_w2', 3, None),
    7: ('conv7_down', 'p7_w2', 2, None),
}

# TF head scope -> (TF layer-name prefix, PyTorch module-name prefix).
_HEADS = {
    'box_net': ('box', 'reg'),
    'class_net': ('class', 'cls'),
}
_HEAD_CONV_MAPPING = _sep_conv_mapping('', '')


def _fpn_node_mapping(fnode_id, fpn_idx):
    """Return the TF -> PyTorch suffix mapping for one BiFPN node.

    Resample (channel-adjust) variables are only mapped for cell 0.
    """
    conv_name, _, _, resample = _FPN_NODES[fnode_id]
    op_scope = 'op_after_combine%d' % (fnode_id + 5)
    mapping = _merged(
        _sep_conv_mapping(op_scope + '/conv/', conv_name + '.'),
        _bn_mapping(op_scope + '/bn', conv_name + '.bn'),
    )
    if fpn_idx == 0 and resample is not None:
        mapping.update(_resample_mapping(*resample))
    return mapping


def _wsm_index(name):
    """Return the input index of a ``WSM``/``WSM_<k>`` name, else None."""
    if name == 'WSM':
        return 0
    if name.startswith('WSM_') and name[4:].isdigit():
        return int(name[4:])
    return None


def _backbone_key(seg, tf_name):
    """PyTorch name for a backbone variable, or None for unmapped scopes."""
    leaf = '/'.join(seg[2:])
    if seg[1] == 'stem':
        return 'backbone.layers.0.' + _lookup(_STEM_MAPPING, leaf, tf_name)
    if seg[1].startswith('blocks_'):
        block_idx = int(seg[1][len('blocks_'):])
        layer = _lookup(_BLOCK_IDX_TO_LAYER, block_idx, tf_name)
        mapping = _FIRST_BLOCK_MAPPING if block_idx == 0 else _BLOCK_MAPPING
        return 'backbone.layers.%s.%s' % (
            layer, _lookup(mapping, leaf, tf_name))
    return None


def _head_key(seg, tf_name):
    """PyTorch name for a ``box_net``/``class_net`` variable."""
    tf_head, pt_head = _HEADS[seg[0]]
    scope = seg[1]
    leaf = '/'.join(seg[2:])
    if tf_head + '-predict' in scope:
        return 'bbox_head.%s_header.%s' % (
            pt_head, _lookup(_HEAD_CONV_MAPPING, leaf, tf_name))
    # Other scopes are expected to be "<head>-<conv>" or
    # "<head>-<conv>-bn-<n>"; split instead of using character offsets so
    # multi-digit indices parse correctly.
    parts = scope.split('-')
    conv_idx = int(parts[1])
    if 'bn' in scope:
        # The TF bn suffix is shifted down by 3 to give the list index.
        bn_idx = int(parts[3]) - 3
        return 'bbox_head.%s_bn_list.%d.%d.%s' % (
            pt_head, conv_idx, bn_idx, _lookup(_BN_PARAMS, leaf, tf_name))
    return 'bbox_head.%s_conv_list.%d.%s' % (
        pt_head, conv_idx, _lookup(_HEAD_CONV_MAPPING, leaf, tf_name))


def convert_key(model_name, bifpn_repeats, weights):
    """Rename TF EfficientDet checkpoint variables to PyTorch parameter names.

    Args:
        model_name (str): Top-level TF scope of the backbone, e.g.
            ``'efficientnet-b0'``.
        bifpn_repeats (int): Number of BiFPN cells; every
            ``fpn_cells/cell_<i>`` must satisfy ``0 <= i < bifpn_repeats``.
        weights (dict[str, torch.Tensor]): TF variable name -> tensor
            (``main`` passes tensors already permuted by ``tf2pth``).

    Returns:
        dict[str, torch.Tensor]: PyTorch parameter name -> tensor. Depthwise
        kernels come back with their first two axes swapped. Each BiFPN node's
        ``WSM``/``WSM_<k>`` scalars are gathered into one float64 vector (e.g.
        ``neck.bifpn.<i>.p4_w2``), emitted only when every input of that node
        was present.

    Names containing ``Exponential`` or ``global_step`` and single-component
    names are skipped silently. Two-component names, variables in unrecognized
    scopes and incomplete fusion vectors are dropped and reported in a single
    ``UserWarning``.

    Raises:
        KeyError: a variable inside a recognized scope has no known mapping.
        ValueError: a BiFPN cell index is outside ``range(bifpn_repeats)``.
    """
    converted = {}
    # (cell index, fusion-weight name) -> (float64 values, per-input seen flags)
    # Explicit flags replace a magic "unset" value that a real weight could hit.
    fusion = {}
    dropped = []

    for tf_name, value in weights.items():
        # EMA shadow variables and the step counter are not converted.
        if 'Exponential' in tf_name or 'global_step' in tf_name:
            continue
        seg = tf_name.split('/')
        if len(seg) == 1:
            continue
        if len(seg) == 2:
            # Every mapped layout has at least three components; skip shorter
            # names instead of indexing past the end.
            dropped.append(tf_name)
            continue

        if seg[0] == model_name:
            pt_name = _backbone_key(seg, tf_name)
        elif seg[0] == 'resample_p6':
            pt_name = _lookup(_RESAMPLE_P6_MAPPING, tf_name, tf_name)
        elif seg[0] == 'fpn_cells':
            fpn_idx = int(seg[1][len('cell_'):])
            if not 0 <= fpn_idx < bifpn_repeats:
                raise ValueError(
                    '%s: BiFPN cell index %d is outside range(%d)' %
                    (tf_name, fpn_idx, bifpn_repeats))
            fnode_id = int(seg[2][len('fnode'):])
            node = _FPN_NODES.get(fnode_id)
            wsm_idx = _wsm_index(seg[3]) if len(seg) > 3 else None
            if node is None:
                pt_name = None
            elif wsm_idx is not None:
                _, weight_name, n_inputs, _ = node
                if wsm_idx >= n_inputs:
                    raise KeyError(
                        'no PyTorch mapping for TF variable %r' % tf_name)
                key = (fpn_idx, weight_name)
                if key not in fusion:
                    fusion[key] = (
                        torch.zeros(n_inputs, dtype=torch.float64),
                        [False] * n_inputs,
                    )
                values, seen = fusion[key]
                values[wsm_idx] = value
                seen[wsm_idx] = True
                continue
            else:
                suffix = _lookup(
                    _fpn_node_mapping(fnode_id, fpn_idx), '/'.join(seg[3:]),
                    tf_name)
                pt_name = 'neck.bifpn.%d.%s' % (fpn_idx, suffix)
        elif seg[0] in _HEADS:
            pt_name = _head_key(seg, tf_name)
        else:
            pt_name = None

        if pt_name is None:
            dropped.append(tf_name)
            continue
        # tf2pth applies the generic conv permutation; depthwise kernels also
        # need their first two axes swapped (presumably TF's channel-multiplier
        # axis vs PyTorch's per-group input axis -- confirm against the model).
        if seg[-1] == 'depthwise_kernel':
            value = value.transpose(1, 0)
        converted[pt_name] = value

    # Emit each fusion-weight vector only once all of its inputs were seen.
    for (fpn_idx, weight_name), (values, seen) in fusion.items():
        pt_name = 'neck.bifpn.%d.%s' % (fpn_idx, weight_name)
        if all(seen):
            converted[pt_name] = values
        else:
            dropped.append(pt_name + ' (incomplete fusion weights)')

    if dropped:
        dropped.sort()
        preview = ', '.join(dropped[:10])
        if len(dropped) > 10:
            preview += ', ...'
        warnings.warn(
            'convert_key dropped %d TensorFlow variable(s) without a PyTorch '
            'mapping: %s' % (len(dropped), preview))
    return converted


def parse_args():
    """Parse the converter's command-line options.

    All three options are required: the conversion cannot run without any of
    them, so argparse reports a missing one with a usage error instead of the
    script failing later on a ``None`` value.

    Returns:
        argparse.Namespace: with string attributes ``backbone``,
        ``tensorflow_weight`` and ``out_weight``.
    """
    parser = argparse.ArgumentParser(
        description='convert efficientdet weight from tensorflow to pytorch')
    parser.add_argument(
        '--backbone',
        type=str,
        required=True,
        help='efficientnet model name, like efficientnet-b0')
    parser.add_argument(
        '--tensorflow_weight',
        type=str,
        required=True,
        help='efficientdet tensorflow weight name, like efficientdet-d0/model')
    parser.add_argument(
        '--out_weight',
        type=str,
        required=True,
        help='efficientdet pytorch weight name like demo.pth')
    return parser.parse_args()


def main():
    """Convert a TF EfficientDet checkpoint into a PyTorch ``.pth`` file.

    Reads every variable of ``--tensorflow_weight``, permutes it with
    ``tf2pth``, renames it with ``convert_key`` and saves
    ``{'state_dict': ...}`` to ``--out_weight`` with ``torch.save``.

    Raises:
        ValueError: ``--backbone`` is not ``efficientnet-b<N>`` with ``N`` a
            key of ``repeat_map``.
    """
    args = parse_args()
    model_name = args.backbone
    ori_weight_name = args.tensorflow_weight
    out_name = args.out_weight

    # EfficientNet variant N (from "efficientnet-bN") -> number of BiFPN cells.
    repeat_map = {
        0: 3,
        1: 4,
        2: 5,
        3: 6,
        4: 7,
        5: 7,
        6: 8,
        7: 8,
    }

    # Parse N from the name instead of a fixed character offset, and do it
    # before loading the checkpoint so a bad --backbone fails fast.
    _, sep, variant = model_name.rpartition('-b')
    if not sep or not variant.isdigit() or int(variant) not in repeat_map:
        raise ValueError(
            'unsupported --backbone %r; expected efficientnet-b<N> with N in '
            '%s' % (model_name, sorted(repeat_map)))
    bifpn_repeats = repeat_map[int(variant)]

    reader = py_checkpoint_reader.NewCheckpointReader(ori_weight_name)
    weights = {
        name: torch.as_tensor(tf2pth(reader.get_tensor(name)))
        for name in reader.get_variable_to_shape_map()
    }
    out = convert_key(model_name, bifpn_repeats, weights)
    torch.save({'state_dict': out}, out_name)


# Script entry point: run the conversion only when executed directly, not when
# this module is imported.
if __name__ == '__main__':
    main()