#!/usr/bin/env python3

# ----------------------------------------------------------------------
# Created: fre jun 18 02:03:33 2021 (+0200)
# Last-Updated:
# Filename: chef
# Author: Yinan Yu
# Description:
# ----------------------------------------------------------------------

import os
import sys
import ast
from pprint import pprint
import argparse
import subprocess

from alex.alex import compare, const, util, dsl_parser, registry
from alex.alex.logger import logger
from alex.annotators import code_gen


def _parse_verbose_flag(value):
    """Turn the ``--verbose`` argument into a Python value.

    Non-string values are returned untouched.  Strings are first evaluated as
    Python literals, so "True", "False", "0" or "None" behave exactly as
    before.  Spellings that are not valid literals (e.g. "true", "no") are then
    matched case-insensitively; anything else re-raises the original
    literal-evaluation error.
    """
    if not isinstance(value, str):
        return value
    try:
        return ast.literal_eval(value)
    except (ValueError, SyntaxError):
        lowered = value.strip().lower()
        if lowered in ("true", "yes", "on"):
            return True
        if lowered in ("false", "no", "off"):
            return False
        raise


def diff(configs, to_png, verbose, dpi, mode):
    """Compare two network configurations and optionally print the changes.

    Args:
        configs: Pair ``(config1, config2)`` of network configurations.
        to_png: Forwarded to the comparison as ``render_to`` when truthy,
            otherwise ``None`` is forwarded.
        verbose: Whether to print the non-"MATCH" operations.  May be a string
            such as "True"/"False" (as delivered by the command line).
        dpi: Forwarded to the comparison as ``dpi``.
        mode: "diff" selects ``compare.diff``; "dist" selects ``compare.dist``.

    Raises:
        ValueError: If ``mode`` is neither "diff" nor "dist", or if ``verbose``
            is a string that cannot be interpreted.
    """
    verbose = _parse_verbose_flag(verbose)
    # Reject an unknown mode up front, before any comparison work is started.
    if mode not in ("diff", "dist"):
        raise ValueError("Mode %s not recognized" % mode)

    path = to_png if to_png else None
    config1, config2 = configs
    logger.info("Comparing %s and %s" % (config1, config2))

    comparer = compare.diff if mode == "diff" else compare.dist
    _, operations = comparer(config1, config2, render_to=path, dpi=dpi)

    if verbose:
        logger.info("Diffs are printed below")
        for operation in operations:
            # Only operations whose tag is not "MATCH" are worth showing.
            if operation[1][0] != "MATCH":
                pprint(operation)
        logger.info("Set --verbose False to avoid output")


def ls(option):
    """Print one of the registry listings selected by *option*.

    "components" and "recipes" are printed without the "root" entry; the other
    registry-backed options print the keys of the matching registry mapping.
    "functions" prints a fixed header line.  Any other option prints nothing.
    """
    if option == "functions":
        print("type", "name", "inputs", "repeat", "hyperparams", "visible")
        return

    def _without_root(names):
        return [name for name in names if name != "root"]

    # Each entry is a thunk so that only the requested registry is touched.
    listings = {
        "components": lambda: _without_root(registry.ALL_COMPONENTS),
        "recipes": lambda: _without_root(registry.RECIPE_TYPES),
        "ingredients": lambda: list(registry.INGREDIENT_TYPES.keys()),
        "initializers": lambda: list(registry.ALL_INITIALIZERS.keys()),
        "regularizers": lambda: list(registry.REGULARIZERS.keys()),
        "optimizers": lambda: list(registry.OPTIMIZER_INGREDIENTS.keys()),
        "losses": lambda: list(registry.LOSSES.keys()),
    }
    produce = listings.get(option)
    if produce is not None:
        pprint(produce())


def inspect(component):
    """Show the hyperparameters stored for *component*.

    Looks for ``<component>.yml`` under ``const.COMPONENT_BASE_PATH``.  When
    the file exists its parsed content is pretty-printed.  Otherwise a message
    is logged: informational if the name is a registered component, an error
    if it is not.
    """
    yml_path = os.path.join(const.COMPONENT_BASE_PATH, component + ".yml")

    if os.path.exists(yml_path):
        pprint(util.read_yaml(yml_path))
        return

    if component not in registry.ALL_COMPONENTS:
        logger.error("Oh no ingredient %s does not exist in Alex yet!\n"
                     "Help us improve please?" % component)
        return

    logger.info("Function %s has no hyperparameter! Alex loves those!" % component)


def _split_checkpoint(path):
    """Split *path* at its last "/" into ``[directory, name]``."""
    directory, _, name = path.rpartition("/")
    return [directory, name]


def codegen(engine, config, out_dir, filename,
            ckpt_from=None, ckpt_to=None, run_code=False, append=None,
            docker_image=None):
    """Generate Python code for a network configuration and optionally run it.

    Args:
        engine: Framework name handed to the code generator.
        config: Sequence whose first element is the network configuration.
        out_dir: Directory the generated file is written to.
        filename: Name of the generated file inside ``out_dir``.
        ckpt_from: Optional checkpoint path to load from; split at the last
            "/" into directory and file name before being passed on.
        ckpt_to: Optional checkpoint path to save to; split the same way.
        run_code: When true, run the generated file with ``docker run``.
        append: Optional file concatenated onto the generated file.
        docker_image: Docker image used for running; falls back to ``engine``.

    Exits the process with status 1 when no configuration is given.  Errors
    from code generation or from the docker run are logged, not raised; a
    failed generation stops before the append/run steps.
    """
    if not docker_image:
        docker_image = engine
    if not config:
        logger.error("Must give a network configuration")
        sys.exit(1)
    config = config[0]

    load_dir = None
    try:
        if ckpt_from is not None:
            ckpt_from = _split_checkpoint(ckpt_from)
            load_dir = ckpt_from[0]
        if ckpt_to is not None:
            ckpt_to = _split_checkpoint(ckpt_to)

        filepath = os.path.join(out_dir, filename)
        code_gen.generate_python(filepath,
                                 config,
                                 engine=engine,
                                 dirname=out_dir,
                                 load_ckpt=ckpt_from,
                                 save_ckpt=ckpt_to,
                                 def_only=False)
    except Exception as err:
        logger.error(err)
        # Nothing usable was generated, so there is nothing to append to or run.
        return
    logger.info("Generated code is written to file %s" % filepath)

    if append:
        util.concatenate_files([filepath, append], filepath)

    if not run_code:
        return

    # Build an argument list (no shell) so paths containing spaces or shell
    # metacharacters are passed to docker verbatim.
    working_dir = os.getcwd()
    command = ["docker", "run", "--shm-size=5g", "--gpus", "all",
               "-v", "%s:/ws/" % working_dir]
    if load_dir is not None:
        # NOTE(review): this mounts the checkpoint directory on /ws/ as well,
        # exactly as the previous command string did -- confirm the target.
        command += ["-v", "%s/%s:/ws/" % (working_dir, load_dir)]
    command += ["-w", "/ws/", "--rm", docker_image, "python", filepath]

    try:
        subprocess.run(command, check=True)
    except (subprocess.SubprocessError, OSError) as err:
        logger.info(" ".join(command))
        logger.error(err)


def render(config, type, path, level):
    """Render a network configuration to an image file.

    Args:
        config: Sequence whose first element is the network configuration.
        type: "ast" renders via ``dsl_parser.make_ast_from_yml``; "graph"
            renders via ``dsl_parser.make_graph_from_yml``.
        path: Output path handed to the renderer.
        level: Extra argument handed to the graph renderer only.

    Exits the process with status 1 when no configuration is given.  An
    unrecognised ``type`` is logged as an error and nothing is rendered.
    """
    if not config:
        logger.error("Must give a network configuration")
        sys.exit(1)
    config = config[0]

    if type == "ast":
        dsl_parser.make_ast_from_yml(config, path)
    elif type == "graph":
        dsl_parser.make_graph_from_yml(config, path, level)
    else:
        # Do not claim an image was written when no renderer ran.
        logger.error("Render type %s not recognized; choose ast or graph" % type)
        return

    logger.info("Image written to file %s" % path)


def main(fn, kwargs=None):
    """Dispatch to the sub-command named *fn*.

    Args:
        fn: Key into the module-level ``fns`` table (built in the
            ``__main__`` section of this script).
        kwargs: Keyword arguments for the selected function; ``None`` means
            no arguments.

    Raises:
        KeyError: If ``fn`` is not a key of ``fns``.
    """
    # A fresh dict per call avoids sharing one mutable default between calls.
    fns[fn](**(kwargs or {}))


if __name__ == "__main__":

    # Without a sub-command there is nothing to dispatch to.
    if len(sys.argv) <= 1:
        logger.warning("Choose one of the following: diff, ls, inspect, codegen, render; or -h for help")
        sys.exit()

    # The first command-line token selects the handler used by main().
    fn = sys.argv[1]
    fns = dict(diff=diff,
               ls=ls,
               inspect=inspect,
               codegen=codegen,
               render=render)

    parser = argparse.ArgumentParser(description="Alex network analyzer")
    subparsers = parser.add_subparsers(help="")

    # One entry per sub-command: (name, help text, ordered argument specs).
    # Each argument spec is (flag or positional name, add_argument keywords).
    command_specs = (
        ("diff",
         "Change log between two networks, e.g. alex-nn diff example_config_1.yml example_config_2.yml",
         (
             ("configs", dict(metavar="Networks to compare",
                              type=str,
                              nargs=2,
                              help="The orignal and modified network configurations")),
             ("--to_png", dict(metavar="Save diff to a png file",
                               type=str,
                               nargs="?",
                               default=True,
                               help="Save diff to the specified png file")),
             ("--verbose", dict(metavar="Print the diffs (True/False)",
                                type=str,
                                nargs="?",
                                default=True,
                                help="Show the diffs")),
             ("--mode", dict(metavar="diff/dist",
                             type=str,
                             nargs="?",
                             default="diff",
                             help="Difference or distance")),
             ("--dpi", dict(metavar="dpi of the image",
                            type=int,
                            nargs="?",
                            default=800,
                            help="Resolution of the image")),
         )),
        ("ls",
         "List information, e.g. alex-nn ls [functions, coponents, recipes, ingredients, initializers, regularizers, optimizers, losses]",
         (
             ("option", dict(metavar="option",
                             type=str,
                             nargs="?",
                             default="components",
                             help="What to ls?")),
         )),
        ("inspect",
         "Inspect hyperparameters of an ingredient, e.g. alex-nn inspect conv",
         (
             ("component", dict(metavar="choose a component to inspect, e.g. conv, relu, etc",
                                type=str,
                                help="Which component to inspect?")),
         )),
        ("codegen",
         "Generate python code, e.g. alex-nn codegen example_config.yml",
         (
             ("config", dict(metavar="Network configuration file path",
                             type=str,
                             nargs=1,
                             help="Network configuration")),
             ("--engine", dict(metavar="Currently support tf and pytorch",
                               type=str,
                               default="pytorch",
                               help="Which framework?")),
             ("--out_dir", dict(metavar="Output dir",
                                type=str,
                                default="./",
                                help="Python file will be written to this dir")),
             ("--filename", dict(metavar="Output file name",
                                 type=str,
                                 default="generated.py",
                                 help="Python file name (not the full path)")),
             ("--ckpt_from", dict(metavar="Checkpoint path to load from",
                                  type=str,
                                  help="Load from checkpoint")),
             ("--ckpt_to", dict(metavar="Checkpoint path to save to",
                                type=str,
                                help="Save to checkpoint")),
             ("--run", dict(dest="run_code",
                            action="store_true",
                            default=False)),
             ("--append", dict(metavar="Append file to generated code",
                               type=str)),
             ("--docker_image", dict(metavar="The tag of the docker image. Set to engine (e.g. pytorch, tf, etc) if tag is not set",
                                     type=str)),
         )),
        ("render",
         "Render ast or graph from a config",
         (
             ("config", dict(metavar="Network configuration file path",
                             type=str,
                             nargs=1,
                             help="Network configuration")),
             ("--type", dict(metavar="ast/graph",
                             type=str,
                             nargs="?",
                             default="ast",
                             help="What to render?")),
             ("--path", dict(metavar="Path",
                             type=str,
                             nargs="?",
                             default="./network.png",
                             help="Path where the png file goes")),
             ("--level", dict(metavar="Level of the graph",
                              type=int,
                              nargs="?",
                              default=2,
                              help="For graph you can hide the details of the recipe")),
         )),
    )

    # Register the sub-commands and their arguments in declaration order.
    for command_name, command_help, argument_specs in command_specs:
        command_parser = subparsers.add_parser(command_name, help=command_help)
        for argument_name, argument_options in argument_specs:
            command_parser.add_argument(argument_name, **argument_options)

    args = parser.parse_args()
    main(fn, vars(args))
