#!/usr/bin/env python3

# ----------------------------------------------------------------------
# Created: fre jun 18 02:03:33 2021 (+0200)
# Last-Updated:
# Filename: chef
# Author: Yinan Yu
# Description: Command-line front end for Alex (diff, ls, inspect, codegen, render)
# ----------------------------------------------------------------------

import os, sys, glob
import ast
from pprint import pprint
import argparse

from alex.alex import compare, const, util, dsl_parser
from alex.annotators import code_gen


def _parse_flag(value):
    """Convert a command-line flag value into a Python value.

    argparse delivers ``--to_png`` / ``--verbose`` as strings (``type=str``).
    Any non-empty string such as ``"False"`` is truthy, so the value must be
    parsed before it is used as a condition. Common boolean spellings are
    matched case-insensitively. Other strings go through ``ast.literal_eval``.
    Non-string values are returned unchanged, including ``None``, which
    argparse yields for a bare ``nargs="?"`` flag.
    """
    if not isinstance(value, str):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "yes", "on"):
        return True
    if lowered in ("false", "no", "off", "none", ""):
        return False
    return ast.literal_eval(value)


def diff(configs, to_png=True, path="./diff.png", verbose=True, dpi=None):
    """Compare two network configurations and report the differences.

    Args:
        configs: pair ``(config1, config2)`` of the original and modified
            configurations. From the command line these are file paths.
            Both are forwarded to ``compare.diff``.
        to_png: whether to render the diff image to *path*. Accepts a bool
            or a string such as ``"True"`` / ``"False"``.
        path: image output path. It is ignored when *to_png* is false.
        verbose: whether to print every operation that is not a match.
            Accepts the same bool or string forms as *to_png*.
        dpi: resolution forwarded to ``compare.diff``.
    """
    to_png = _parse_flag(to_png)
    verbose = _parse_flag(verbose)
    # compare.diff only renders an image when render_to is not None.
    render_to = path if to_png else None

    config1, config2 = configs
    _cost, operations = compare.diff(config1, config2, render_to=render_to, dpi=dpi)

    if verbose:
        for operation in operations:
            # Entries whose operation[1][0] is "MATCH" are unchanged; skip them.
            if operation[1][0] != "MATCH":
                print(operation)
        print("Set --verbose False to avoid output")


def ls(option):
    """Pretty-print the names available for *option*.

    Args:
        option: one of ``components``, ``recipes``, ``ingredients``,
            ``initializers``, ``regularizers``, ``optimizers``, ``losses``
            or ``functions``.

    For ``functions`` only a header row of column names is printed. An
    unknown *option* exits with a non-zero status and lists the valid
    choices; previously it silently printed nothing.
    """
    if option == "functions":
        print("type", "name", "inputs", "repeat", "hyperparams", "visible")
        return

    # Each entry builds its list lazily, so ``const`` is only read for the
    # option actually requested. The "root" entry is filtered out of
    # components and recipes.
    listers = {
        "components": lambda: [name for name in const.ALL_COMPONENTS if name != "root"],
        "recipes": lambda: [name for name in const.ALL_RECIPES if name != "root"],
        "ingredients": lambda: list(const.INGREDIENT_TYPES.keys()),
        "initializers": lambda: list(const.ALL_INITIALIZERS.keys()),
        "regularizers": lambda: list(const.REGULARIZERS.keys()),
        "optimizers": lambda: list(const.OPTIMIZERS.keys()),
        "losses": lambda: list(const.LOSSES.keys()),
    }
    if option not in listers:
        choices = ", ".join(list(listers) + ["functions"])
        sys.exit("Unknown option %s for ls; choose one of: %s" % (option, choices))
    pprint(listers[option]())


def inspect(component):
    """Pretty-print the hyperparameters of *component*.

    The hyperparameters are read from
    ``<const.COMPONENT_BASE_PATH>/<component>.yml``. When that file does not
    exist, a message says whether the component is known without
    hyperparameters or is not known at all.
    """
    yml_path = os.path.join(const.COMPONENT_BASE_PATH, component + ".yml")

    if os.path.exists(yml_path):
        pprint(util.read_yaml(yml_path))
        return

    # No YAML file: tell a known component apart from an unknown one.
    if component in const.ALL_COMPONENTS:
        print("Ingredient %s has no hyperparameter! Alex loves those!" % component)
        return
    print("Oh no ingredient %s does not exist in Alex yet!\n"
          "Help us improve please?" % component)


def codegen(engine, config, out_dir, filename):
    """Generate a Python source file from a network configuration.

    Args:
        engine: target framework passed to ``code_gen.CodeGen``. The CLI
            help mentions tf and pytorch.
        config: network configuration file path. Required.
        out_dir: directory the generated file is written to.
        filename: name of the generated Python file.

    Exits with a non-zero status, with the message on stderr, when *config*
    is missing or code generation fails, so shell callers can detect the
    error.
    """
    if not config:
        # sys.exit(str) prints to stderr and exits with status 1.
        sys.exit("Must give a network configuration --config")
    try:
        code_generator = code_gen.CodeGen(filename,
                                          config,
                                          engine=engine,
                                          dirname=out_dir)
        code_generator.generate_python()
    except Exception as err:  # top-level CLI boundary: report, then fail
        print(err, file=sys.stderr)
        sys.exit(1)
    print("Generated code is written to file %s in dir %s" % (filename, out_dir))


def render(config, type, path, level):
    """Render the AST or the graph of a network configuration to an image.

    Args:
        config: network configuration file path. Required.
        type: ``"ast"`` or ``"graph"``. It is named ``type`` to match the
            ``--type`` command-line option, since the parameters are filled
            from ``vars(args)``, so it shadows the builtin inside this
            function.
        path: output image path.
        level: graph detail level. Only used when *type* is ``"graph"``.

    Exits with a non-zero status when *config* is missing or *type* is not
    recognised, instead of reporting an image that was never written.
    """
    if not config:
        sys.exit("Must give a network configuration --config example.yml")

    if type == "ast":
        dsl_parser.make_ast_from_yml(config, path)
    elif type == "graph":
        dsl_parser.make_graph_from_yml(config, path, level)
    else:
        sys.exit("Unknown render type %s; choose ast or graph" % type)

    print("Image written to file %s" % path)


def merge(config, ckpt):
    """Placeholder for a merge command; not implemented yet.

    Accepts *config* and *ckpt*, does nothing and returns ``None``. It is
    not listed in this script's ``fns`` dispatch table.
    """
    return None


def main(fn, kwargs=None):
    """Run the sub-command named *fn* with keyword arguments *kwargs*.

    The command is looked up in the module-level ``fns`` table, which is
    defined in the ``__main__`` section of this script. *kwargs* normally
    comes from ``vars(parser.parse_args())``. When omitted, no arguments
    are passed.
    """
    # None rather than a shared mutable ``dict()`` default argument.
    if kwargs is None:
        kwargs = {}
    fns[fn](**kwargs)


if __name__ == "__main__":
    if len(sys.argv) < 2:
        # A missing sub-command is a usage error: report it on stderr with a
        # non-zero exit status.
        sys.exit("Choose one of the following: diff, ls, inspect, codegen, render; or -h for help")
    fn = sys.argv[1]
    # Sub-command dispatch table. It must stay a module-level global
    # because main() looks commands up in it by name.
    fns = {"diff": diff,
           "ls": ls,
           "inspect": inspect,
           "codegen": codegen,
           "render": render}

    parser = argparse.ArgumentParser(description="Alex network analyzer")
    subparsers = parser.add_subparsers(help="")

    diff_parser = subparsers.add_parser("diff",
                                        help="Change log between two networks, e.g. alex-nn diff example_config_1.yml example_config_2.yml")
    diff_parser.add_argument("configs", metavar="Networks to compare", type=str, nargs=2,
                             help="The original and modified network configurations")

    diff_parser.add_argument("--to_png", metavar="Save diff to a png file",
                             type=str, nargs="?", default=True,
                             help="Save diff to a png file in the specified path")

    diff_parser.add_argument("--path", metavar="Path",
                             type=str, nargs="?", default="./diff.png",
                             help="Path where the png file goes")

    diff_parser.add_argument("--verbose", metavar="Print the diffs",
                             type=str, nargs="?", default=True,
                             help="Show the diffs")

    diff_parser.add_argument("--dpi", metavar="dpi of the image",
                             default=800,
                             type=int, nargs="?",
                             help="Resolution of the image")

    ls_parser = subparsers.add_parser("ls", help="List information, e.g. alex-nn ls [functions, components, recipes, ingredients, initializers, regularizers, optimizers, losses]")
    ls_parser.add_argument("option", metavar="option", type=str, nargs="?",
                           default="components",
                           help="What to ls?")

    inspect_parser = subparsers.add_parser("inspect", help="Inspect hyperparameters of an ingredient, e.g. alex-nn inspect conv")
    inspect_parser.add_argument("component", metavar="choose a component to inspect, e.g. conv, relu, etc",
                                type=str,
                                help="Which component to inspect?")

    codegen_parser = subparsers.add_parser("codegen", help="Generate python code, e.g. alex-nn codegen example_config.yml")

    codegen_parser.add_argument("--engine",
                                default="pytorch",
                                metavar="Currently support tf and pytorch",
                                type=str,
                                help="Which framework?")

    codegen_parser.add_argument("--config", metavar="Network configuration file path",
                                type=str,
                                help="Network configuration")

    codegen_parser.add_argument("--out_dir", metavar="Output dir", default="./",
                                type=str,
                                help="Python file will be written to this dir")

    codegen_parser.add_argument("--filename", metavar="Output file name",
                                type=str,
                                default="generated.py",
                                help="Python file name")

    render_parser = subparsers.add_parser("render",
                                          help="Render ast or graph from a config")

    render_parser.add_argument("config",
                               metavar="Network configuration file path",
                               type=str,
                               help="Network configuration")

    render_parser.add_argument("--type", metavar="ast/graph",
                               type=str,
                               nargs="?",
                               default="ast",
                               help="What to render?")

    render_parser.add_argument("--path", metavar="Path",
                               type=str, nargs="?", default="./network.png",
                               help="Path where the png file goes")

    render_parser.add_argument("--level", metavar="Level of the graph",
                               type=int, nargs="?", default=2,
                               help="For graph you can hide the details of the recipe")

    args = parser.parse_args()
    # add_subparsers() was given no dest, so vars(args) holds only the chosen
    # sub-command's options. Their dest names match that function's
    # parameter names.
    main(fn, vars(args))
