# Copyright (c) OpenMMLab. All rights reserved.
import logging
import os
import os.path as osp
from argparse import ArgumentParser

from mmengine.config import Config, DictAction
from mmengine.logging import MMLogger, print_log
from mmengine.registry import RUNNERS
from mmengine.runner import Runner

from mmdet.testing import replace_to_ceph
from mmdet.utils import register_all_modules, replace_cfg_vals


def parse_args():
    parser = ArgumentParser()
    parser.add_argument('config', help='test config file path')
    parser.add_argument('--work-dir', help='the dir to save logs and models')
    parser.add_argument('--ceph', action='store_true')
    parser.add_argument('--save-ckpt', action='store_true')
    parser.add_argument(
        '--amp',
        action='store_true',
        default=False,
        help='enable automatic-mixed-precision training')
    parser.add_argument(
        '--auto-scale-lr',
        action='store_true',
        help='enable automatically scaling LR.')
    parser.add_argument(
        '--resume',
        action='store_true',
        help='resume from the latest checkpoint in the work_dir automatically')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)
    args = parser.parse_args()
    return args


# TODO: Need to refactor train.py so that it can be reused.
def fast_train_model(config_name, args, logger=None):
    cfg = Config.fromfile(config_name)
    cfg = replace_cfg_vals(cfg)
    cfg.launcher = args.launcher
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    # work_dir is determined in this priority: CLI > segment in file > filename
    if args.work_dir is not None:
        # update configs according to CLI args if args.work_dir is not None
        cfg.work_dir = osp.join(args.work_dir,
                                osp.splitext(osp.basename(config_name))[0])
    elif cfg.get('work_dir', None) is None:
        # use config filename as default work_dir if cfg.work_dir is None
        cfg.work_dir = osp.join('./work_dirs',
                                osp.splitext(osp.basename(config_name))[0])

    ckpt_hook = cfg.default_hooks.checkpoint
    by_epoch = ckpt_hook.get('by_epoch', True)
    fast_stop_hook = dict(type='FastStopTrainingHook')
    fast_stop_hook['by_epoch'] = by_epoch
    if args.save_ckpt:
        if by_epoch:
            interval = 1
            stop_iter_or_epoch = 2
        else:
            interval = 4
            stop_iter_or_epoch = 10
        fast_stop_hook['stop_iter_or_epoch'] = stop_iter_or_epoch
        fast_stop_hook['save_ckpt'] = True
        ckpt_hook.interval = interval

    if 'custom_hooks' in cfg:
        cfg.custom_hooks.append(fast_stop_hook)
    else:
        custom_hooks = [fast_stop_hook]
        cfg.custom_hooks = custom_hooks

    # TODO: temporary plan
    if 'visualizer' in cfg:
        if 'name' in cfg.visualizer:
            del cfg.visualizer.name

    # enable automatic-mixed-precision training
    if args.amp is True:
        optim_wrapper = cfg.optim_wrapper.type
        if optim_wrapper == 'AmpOptimWrapper':
            print_log(
                'AMP training is already enabled in your config.',
                logger='current',
                level=logging.WARNING)
        else:
            assert optim_wrapper == 'OptimWrapper', (
                '`--amp` is only supported when the optimizer wrapper type is '
                f'`OptimWrapper` but got {optim_wrapper}.')
            cfg.optim_wrapper.type = 'AmpOptimWrapper'
            cfg.optim_wrapper.loss_scale = 'dynamic'

    # enable automatically scaling LR
    if args.auto_scale_lr:
        if 'auto_scale_lr' in cfg and \
                'enable' in cfg.auto_scale_lr and \
                'base_batch_size' in cfg.auto_scale_lr:
            cfg.auto_scale_lr.enable = True
        else:
            raise RuntimeError('Can not find "auto_scale_lr" or '
                               '"auto_scale_lr.enable" or '
                               '"auto_scale_lr.base_batch_size" in your'
                               ' configuration file.')

    if args.ceph:
        replace_to_ceph(cfg)

    cfg.resume = args.resume

    # build the runner from config
    if 'runner_type' not in cfg:
        # build the default runner
        runner = Runner.from_cfg(cfg)
    else:
        # build customized runner from the registry
        # if 'runner_type' is set in the cfg
        runner = RUNNERS.build(cfg)

    runner.train()


# Sample test whether the train code is correct
def main(args):
    # register all modules in mmdet into the registries
    register_all_modules(init_default_scope=False)

    config = Config.fromfile(args.config)

    # test all model
    logger = MMLogger.get_instance(
        name='MMLogger',
        log_file='benchmark_train.log',
        log_level=logging.ERROR)

    for model_key in config:
        model_infos = config[model_key]
        if not isinstance(model_infos, list):
            model_infos = [model_infos]
        for model_info in model_infos:
            print('processing: ', model_info['config'], flush=True)
            config_name = model_info['config'].strip()
            try:
                fast_train_model(config_name, args, logger)
            except RuntimeError as e:
                # quick exit is the normal exit message
                if 'quick exit' not in repr(e):
                    logger.error(f'{config_name} " : {repr(e)}')
            except Exception as e:
                logger.error(f'{config_name} " : {repr(e)}')


if __name__ == '__main__':
    args = parse_args()
    main(args)