#!/usr/bin/env python

from argparse import ArgumentParser
import logging
import os
import sys


# Resolved at import time: all commands operate relative to the directory
# the CLI was launched from.
base_path = os.path.abspath(os.getcwd())
# Location of built-in model templates shipped inside a model repository.
repo_template_catalog = os.path.join(base_path, "model_templates/")
# Location of the model definitions; its presence is what marks the CWD
# as the root of a model repository (see validate_cwd_valid).
model_catalog = os.path.join(base_path, "model_definitions/")


def validate_cwd_valid():
    """Exit with an error unless the CWD looks like a model repository root.

    A model repository root is identified by the presence of the
    model_definitions/ directory (see module-level ``model_catalog``).
    """
    if not os.path.exists(model_catalog):
        # BUGFIX: original message read "must the root" (missing "be")
        logging.error("Current working directory must be the root of a model repository to execute this command")
        sys.exit(1)


def init_model_directory(model_util, **kwargs):
    """CLI action: delegate directory scaffolding to the model utility."""
    model_util.init_model_directory()


def add_model(model_util, **kwargs):
    """Interactively add a new model definition to the current repository.

    Prompts for name, description, language and template, validating each
    answer against the combined built-in and custom template catalogs, then
    delegates creation to ``model_util.add_model``. Exits with status 1 on
    any invalid input.
    """
    validate_cwd_valid()

    model_name = input("model name: ")
    model_desc = input("model description: ")

    if model_name == "":
        logging.error("specify model name")
        exit(1)

    catalog, custom_catalog = model_util.get_all_template_catalog()

    # Languages available across both catalogs (iterating a dict yields keys).
    lang = sorted(set(catalog).union(custom_catalog))

    print("These languages are supported: {0}".format(
        ", ".join(lang)))
    model_lang = input("model language: ")
    if model_lang not in lang:
        logging.error("only {0} model languages currently supported.".format(
            ", ".join(lang)))
        exit(1)

    # Templates for the chosen language from whichever catalog(s) list it.
    # ``model_lang`` is guaranteed to be in at least one catalog by the
    # check above, so the .get defaults only cover the missing side.
    models = sorted(set(catalog.get(model_lang, [])).union(custom_catalog.get(model_lang, [])))
    print("templates available for {0}: {1}".format(
        model_lang, ", ".join(models)))
    model_template = input("template type (leave blank for the default one): ")
    if not model_template:
        model_template = "empty"
    if model_template not in models:
        logging.error("only {0} templates currently supported.".format(", ".join(models)))
        exit(1)

    # The template must actually be listed under the language in one of the
    # catalogs (custom templates take no precedence — both branches in the
    # original did exactly the same thing, so they are merged here).
    in_custom = model_lang in custom_catalog and model_template in custom_catalog[model_lang]
    in_repo = model_lang in catalog and model_template in catalog[model_lang]
    if not (in_custom or in_repo):
        logging.error("unable to find the specified model")
        exit(1)

    model_util.add_model(
        model_name=model_name,
        model_desc=model_desc,
        lang=model_lang,
        template=model_template
    )
    logging.info("creating model structure for model: {0}".format(model_util.model["id"]))


def run_model(args, model_util, **kwargs):
    """Train or evaluate a model locally.

    Resolves the model id, the mode (train/evaluate) and the dataset
    metadata — each either from CLI arguments or interactively — then
    delegates to TrainModel / EvaluateModel from the aoa package.
    Exits with status 1 on invalid input or missing configuration.
    """
    from aoa import TrainModel
    from aoa import EvaluateModel
    from aoa import AoaClient
    from aoa import DatasetApi
    import json

    validate_cwd_valid()

    repo_conf = model_util.read_repo_config()
    if not repo_conf:
        logging.error("Please run 'aoa configure --repo'")
        exit(1)

    client = AoaClient()
    client.set_project_id(repo_conf["project_id"])

    if not args.model_id:
        # Interactive model selection from the local catalog.
        catalog = TrainModel.get_model_id(model_catalog, True)
        for key in catalog:
            print("({}) {}".format(key, catalog[key]["name"]))
        index = input("select model: ")
        if int(index) not in catalog:
            logging.error("invalid selection...exiting")
            exit(1)
        model_id = catalog[int(index)]["id"]
    else:
        model_id = args.model_id

    if not args.mode:
        mode = input("select mode (train, evaluate): ")
        if mode == "":
            logging.error("invalid selection...exiting")
            exit(1)
    else:
        mode = args.mode

    if args.local_dataset:
        # Dataset metadata supplied as a local JSON file.
        with open(args.local_dataset, 'r') as f:
            data_conf = json.load(f)

    else:
        dataset_api = DatasetApi(aoa_client=client)

        if not args.dataset_id:
            datasets = dataset_api.find_all()['_embedded']['datasets']
            if len(datasets) > 0:
                for index, dataset in enumerate(datasets):
                    print("({}) {}".format(index, dataset['name']))
            else:
                logging.error("No datasets found in project")
                exit(1)
            index = input("select dataset metadata: ")
            # BUGFIX: was `int(index) > len(datasets)`, which let
            # index == len(datasets) through to an IndexError and let
            # negative values silently select from the end of the list.
            selection = int(index)
            if selection < 0 or selection >= len(datasets):
                logging.error("invalid selection...exiting")
                exit(1)
            dataset = datasets[selection]

        else:
            dataset = dataset_api.find_by_id(args.dataset_id)
            if not dataset:
                # BUGFIX: was `raise("...")`, which raises a TypeError
                # ("exceptions must derive from BaseException") instead of
                # the intended error.
                raise ValueError("No dataset found for datasetId: {}".format(args.dataset_id))

        data_conf = dataset["metadata"]

    if mode == "train":
        train = TrainModel(model_util)
        train.train_model_local(model_id, data_conf=data_conf)
    elif mode == "evaluate":
        evaluate = EvaluateModel(model_util)
        evaluate.evaluate_model_local(model_id, data_conf=data_conf)
    else:
        logging.error("unsupported mode used: " + mode)
        exit(1)


def clone(args, model_util, **kwargs):
    """Clone a project's git repository.

    The project is taken from ``args.project_id`` or selected interactively
    from the projects visible to the configured client. The repository is
    cloned into ``args.path`` if given, otherwise the current directory.

    Raises:
        ValueError: if ``args.project_id`` does not resolve to a project.
    """
    from aoa import AoaClient
    from aoa import ProjectApi

    client = AoaClient()
    project_api = ProjectApi(aoa_client=client)

    if not args.project_id:
        projects = project_api.find_all()['_embedded']['projects']

        for i, project in enumerate(projects):
            print("({}) {}".format(i, project["name"]))

        index = input("Select project: ")

        project = projects[int(index)]

    else:
        project = project_api.find_by_id(args.project_id)
        if not project:
            # BUGFIX: was `raise ("...")`, which raises a TypeError because a
            # str is not an exception; the following exit(1) was unreachable.
            raise ValueError("No project found for id: {}".format(args.project_id))

    path = args.path if args.path else base_path

    model_util.clone_repository(project["gitRepositoryUrl"], path)


def configure(args, model_util, **kwargs):
    """Configure the CLI.

    With --repo, only the project id is prompted for and written via the
    model utility. Otherwise the API endpoint and base64-encoded
    credentials are prompted for and stored in ~/.aoa/config.yaml.
    """
    from getpass import getpass
    from base64 import b64encode
    from pathlib import Path
    import yaml

    if args.repo:
        # Repo-only configuration: just record the project id.
        project_id = input("Enter ProjectId: ")
        model_util.write_repo_config({"project_id": project_id})
        return

    aoa_url = input("API Endpoint: ")
    auth_user = input("Username: ")
    auth_pass = getpass("Password: ")

    # Credentials are stored base64-encoded as "user:password".
    credentials = "{}:{}".format(auth_user, auth_pass)
    client_config = {
        "aoa_url": aoa_url,
        "aoa_credentials": b64encode(credentials.encode()).decode()
    }

    config_dir = "{}/.aoa".format(Path.home())
    Path(config_dir).mkdir(parents=True, exist_ok=True)
    config_file = "{}/config.yaml".format(config_dir)

    with open(config_file, 'w+') as f:
        yaml.dump(client_config, f, default_flow_style=False)


def print_help(args, **kwargs):
    """Default action: print the tool version if --version was given, else usage."""
    from aoa import __version__

    if args.version:
        print("{}".format(__version__))
        return

    kwargs.get("parent_parser").print_help()


def main():
    """Build the argument parser, configure logging and dispatch the action."""
    parent_parser = ArgumentParser(description='AOA CLI')
    parent_parser.add_argument('--debug', action='store_true', help='Enable debug logging')
    parent_parser.add_argument('--version', action='store_true', help='Display the version of this tool')
    parent_parser.set_defaults(func=print_help)

    subparsers = parent_parser.add_subparsers(title="actions", description="valid actions", dest="command")

    parser_add = subparsers.add_parser("add", help="Add model")
    parser_add.set_defaults(func=add_model)

    parser_run = subparsers.add_parser("run", help="Train and Evaluate model",)
    parser_run.add_argument('-id', '--model-id', type=str, help='Id of model')
    parser_run.add_argument('-m', '--mode', type=str, help='Mode (train or evaluate)')
    parser_run.add_argument('-d', '--dataset-id', type=str, help='Remote datasetId')
    parser_run.add_argument('-ld', '--local-dataset', type=str, help='Path to local dataset metadata file')
    parser_run.set_defaults(func=run_model)

    parser_init = subparsers.add_parser("init", help="Initialize model directory with basic structure")
    parser_init.set_defaults(func=init_model_directory)

    parser_clone = subparsers.add_parser("clone", help="Clone Project Repository")
    parser_clone.add_argument('-id', '--project-id', type=str, help='Id of Project to clone')
    parser_clone.add_argument('-p', '--path', type=str, help='Path to clone repository to')
    parser_clone.set_defaults(func=clone)

    parser_configure = subparsers.add_parser("configure", help="Configure AOA client")
    parser_configure.add_argument('--repo', action='store_true', help='Configure the repo only')
    parser_configure.set_defaults(func=configure)

    # Every subcommand accepts --debug; it was last in each subparser's
    # argument list, so adding it here preserves the help ordering.
    for subparser in (parser_add, parser_run, parser_init, parser_clone, parser_configure):
        subparser.add_argument('--debug', action='store_true', help='Enable debug logging')

    args = parent_parser.parse_args()

    logging_level = logging.DEBUG if args.debug else logging.INFO
    logging.basicConfig(format="%(message)s", stream=sys.stdout, level=logging_level)

    # Imported lazily so `--help` works without the aoa package installed.
    from aoa import ModelUtility
    model_util = ModelUtility(base_path, repo_template_catalog)

    args.func(model_util=model_util, args=args, parent_parser=parent_parser)


if __name__ == "__main__":
    try:
        main()
    # Exit quietly on Ctrl-C; everything else propagates as before.
    # (The original's trailing `except: raise` was a no-op — re-raising is
    # the default for unhandled exceptions — and has been removed.)
    except KeyboardInterrupt:
        logging.info("Keyboard interrupt...exiting")
