#!/usr/bin/env python

import argparse
import logging
import os
import requests
import sys


# Absolute path of the directory the CLI operates on; set_cwd() rebinds it (and the two catalogs below).
base_path = os.path.abspath(os.getcwd())
# Template and model-definition folders of the model repository, rooted at base_path (trailing "/" kept).
repo_template_catalog, model_catalog = (
    os.path.join(base_path, sub_dir) for sub_dir in ("model_templates/", "model_definitions/")
)
# Display names of the modes offered by `run`; run_model compares them lower-cased.
available_modes = ["Train", "Evaluate", "Score (Batch)"]


def validate_cwd_valid():
    """Tell whether the current base path looks like a model repository.

    Returns:
        True when the ``model_definitions/`` path (``model_catalog``) exists, False otherwise.
    """
    catalog_present = os.path.exists(model_catalog)
    return catalog_present


def init_model_directory(args, model_util, **kwargs):
    """Handle ``aoa init``: create the basic model repository structure.

    Optionally switches to ``args.cwd`` first. When no repository configuration exists afterwards,
    the user is asked to pick the project the repository belongs to.
    """
    target_dir = args.cwd
    if target_dir:
        set_cwd(target_dir)

    model_util.init_model_directory()

    # An already-configured repository keeps its project selection.
    if model_util.repo_config_exists():
        return
    configure_repo(model_util)


def input_string(name, required=False, tooltip='', password=False, called_from_test=False):
    """Prompt the user for a string value.

    Args:
        name: label used in the ``"Enter <name>: "`` prompt.
        required: when True, keep prompting until a non-empty value is entered.
        tooltip: optional hint printed before every prompt.
        password: read the value without echo (via ``getpass``).
        called_from_test: read passwords with ``input()`` instead of ``getpass`` so tests can feed stdin.

    Returns:
        The entered string (may be empty when ``required`` is False).
    """
    from getpass import getpass

    prompt = "Enter {}: ".format(name)
    # Choose the reader once so that retries keep hiding passwords as well.
    reader = getpass if password and not called_from_test else input

    while True:
        if tooltip != '':
            print(tooltip)

        value = reader(prompt)
        if value != '' or not required:
            return value

        print('Value required. Please try again.')
        print('You may cancel at anytime by pressing Ctrl+C')
        print("")


def input_select(name, values, label='', default=None):
    """Let the user pick one element of ``values`` by its printed index.

    Args:
        name: noun used in the prompt, e.g. "model language".
        values: the selectable options, printed as ``[index] value``.
        label: optional underlined heading printed above the options.
        default: option chosen when the user submits an empty answer (only if present in ``values``).

    Returns:
        The selected element, or None when ``values`` is empty.
    """
    if len(values) == 0:
        return None

    if default:
        prompt = "Select {} by index (or leave blank for the default one): ".format(name)
    else:
        prompt = "Select {} by index: ".format(name)

    while True:
        # The options are re-printed on every attempt so they stay visible after an error message.
        if label != '':
            print_underscored(label)

        for ix, item in enumerate(values):
            default_text = " (default)" if default and default == item else ''
            print("[{}] {}".format(ix, item) + default_text)

        choice = input(prompt)

        if choice == '' and default and default in values:
            return values[values.index(default)]
        # isdecimal() (unlike isnumeric()) only accepts digits int() can parse, so '²' or '½' cannot crash here.
        if choice.isdecimal() and int(choice) < len(values):
            return values[int(choice)]

        print('Wrong selection, please try again by selecting the index number on the first column.')
        print('You may cancel at anytime by pressing Ctrl+C.')
        print("")


def add_model(args, model_util, **kwargs):
    """Handle ``aoa add``: interactively create a new model from a template.

    Asks for name, description, language and template. Templates come from both the built-in and the
    custom catalog returned by ``model_util.get_all_template_catalog()``. Exits with status 1 when the
    working directory is not a model repository or no matching template can be found.
    """
    if args.cwd:
        set_cwd(args.cwd)

    if not validate_cwd_valid():
        logging.error("Current working directory must be the root of a model repository to execute this command")
        sys.exit(1)

    model_name = input_string('model name', True)
    model_desc = input_string('model description', False)

    catalog, custom_catalog = model_util.get_all_template_catalog()

    languages = sorted(set(catalog.keys()) | set(custom_catalog.keys()))

    print("")
    model_lang = input_select("model language", languages, "Supported languages:")
    if model_lang is None:
        # input_select returns None when there is nothing to choose from.
        logging.error("Unable to find the specified model template")
        sys.exit(1)

    templates = sorted(set(catalog.get(model_lang, [])) | set(custom_catalog.get(model_lang, [])))

    print("")
    model_template = input_select("template type", templates, "Supported templates:", 'empty')

    # A template may live in either catalog; both cases are added the same way.
    in_custom = model_lang in custom_catalog and model_template in custom_catalog[model_lang]
    in_builtin = model_lang in catalog and model_template in catalog[model_lang]
    if not (in_custom or in_builtin):
        logging.error("Unable to find the specified model template")
        sys.exit(1)

    model_util.add_model(
        model_name=model_name,
        model_desc=model_desc,
        lang=model_lang,
        template=model_template,
        base_path=base_path
    )
    logging.info("Creating model structure for model: ({}) {}".format(model_util.model["id"], model_util.model["name"]))


def run_model(args, model_util, **kwargs):
    """Handle ``aoa run``: train, evaluate or batch-score a local model.

    The model id, mode and dataset can be given on the command line. Missing or unknown values are
    prompted for interactively. The dataset metadata comes either from a local JSON file
    (``--local-dataset``) or from a remote dataset of the repository's project. Exits with status 1
    when the repository is not configured or nothing can be selected.
    """
    from aoa import TrainModel
    from aoa import EvaluateModel
    from aoa import ScoreModel
    from aoa import AoaClient
    from aoa import DatasetApi
    import json

    if args.cwd:
        set_cwd(args.cwd)

    current_project = get_current_project(model_util, True)
    if not current_project:
        logging.error("Current working directory must be the root of a model repository to execute this command")
        sys.exit(1)

    client = AoaClient()
    client.set_project_id(current_project["id"])

    def _select_model_id(catalog):
        """Prompt for a model by name and return its id, or None if nothing usable was selected."""
        catalog_values = [catalog[i]['name'] for i in catalog]
        selected_model_value = input_select('model', catalog_values, 'Available models:')
        selected_model = next((catalog[i] for i in catalog if catalog[i]['name'] == selected_model_value), None)
        print("")
        # input_select returns None for an empty catalog, so selected_model may be None here.
        if not selected_model or 'id' not in selected_model:
            return None
        return selected_model['id']

    def _select_mode():
        """Prompt for one of available_modes and return it lower-cased."""
        selected_mode = input_select('mode', available_modes, 'Available modes:').lower()
        print("")
        return selected_mode

    def _select_dataset(dataset_list):
        """Prompt for a dataset by name and return its record, or None for an empty list."""
        dataset_values = [i["name"] for i in dataset_list]
        selected_dataset_value = input_select('dataset', dataset_values, 'Available datasets:')
        selected_dataset = next((i for i in dataset_list if i['name'] == selected_dataset_value), None)
        print("")
        return selected_dataset

    available_models = TrainModel.get_model_id(model_catalog, True)

    if args.model_id and any(available_models[i]['id'] == args.model_id for i in available_models):
        model_id = args.model_id
    else:
        if args.model_id:
            print('Model not found. Please select one from the list below or press Ctrl+C to quit.')
        model_id = _select_model_id(available_models)

    if not model_id:
        logging.error("No local model could be selected")
        sys.exit(1)

    if not args.mode:
        mode = _select_mode()
    else:
        # Accept the mode case-insensitively, e.g. "Train" as well as "train".
        requested_mode = args.mode.lower()
        if requested_mode in [x.lower() for x in available_modes]:
            mode = requested_mode
        else:
            print('Mode not found. Please select one from the list below or press Ctrl+C to quit.')
            mode = _select_mode()

    if args.local_dataset:
        with open(args.local_dataset, 'r') as f:
            data_conf = json.load(f)

    else:
        dataset_api = DatasetApi(aoa_client=client, show_archived=False)
        available_datasets = list(dataset_api)

        if args.dataset_id and any(i['id'] == args.dataset_id for i in available_datasets):
            dataset = dataset_api.find_by_id(args.dataset_id)
        else:
            if args.dataset_id:
                print('Dataset not found. Please select one from the list below or press Ctrl+C to quit.')
            dataset = _select_dataset(available_datasets)

        if not dataset:
            logging.error("No datasets found in project")
            sys.exit(1)

        data_conf = dataset["metadata"]

    if mode == "train":
        trainer = TrainModel(model_util)
        trainer.train_model_local(model_id, data_conf=data_conf, base_path=base_path)
    elif mode == "evaluate":
        evaluator = EvaluateModel(model_util)
        evaluator.evaluate_model_local(model_id, data_conf=data_conf, base_path=base_path)
    elif mode == "score (batch)":
        scorer = ScoreModel(model_util)
        scorer.batch_score_model_local(model_id, data_conf=data_conf, base_path=base_path)
    else:
        logging.error("Unsupported mode used: " + mode)
        sys.exit(1)


def list_resources(args, model_util, **kwargs):
    """Handle ``aoa list``: print projects, remote models, local models and/or datasets.

    ``--projects`` prints the project list and exits with status 0. Remote models and datasets are
    listed for the repository's configured project. If none is configured, the user is asked to
    pick one, at most once per invocation. Exits with status 1 (after printing help) when no
    object type was selected.
    """
    from aoa import TrainModel
    from aoa.api.model_api import ModelApi
    from aoa import AoaClient
    from aoa import DatasetApi

    if args.cwd:
        set_cwd(args.cwd)

    if not (args.projects or args.models or args.local_models or args.datasets):
        logging.error("Invalid object selection...")
        kwargs.get("parent_parser").print_help()
        sys.exit(1)

    client = AoaClient()

    if args.projects:
        list_and_select_projects(model_util, None, True, False)
        sys.exit(0)

    project = None
    if args.models or args.datasets:
        # Resolve the project once so that "-m -d" does not prompt for it twice.
        project = get_current_project(model_util)
        if project:
            client.set_project_id(project["id"])
        else:
            project = list_and_select_projects(model_util, client, False, True)

    if args.models:
        model_api = ModelApi(client, show_archived=False)
        print_underscored("List of models for project {}:".format(project["name"]))
        if len(model_api) == 0:
            print("No models were found")
        for i, model in enumerate(model_api):
            print("[{}] Id: ({}) SourceId: ({}) {}".format(i, model["id"], model["sourceId"] if "sourceId" in model else "NA", model["name"]))

    if args.local_models:
        local_models = TrainModel.get_model_id(model_catalog, True)
        print_underscored("List of local models:")
        if len(local_models) == 0:
            print("No local models were found")
        for local_model in local_models:
            print(
                "[{}] SourceId: ({}) {}".format(local_model, local_models[local_model]["id"], local_models[local_model]["name"]))

    if args.datasets:
        dataset_api = DatasetApi(aoa_client=client, show_archived=False)
        print_underscored("List of datasets for project {}:".format(project["name"]))
        if len(dataset_api) == 0:
            print("No datasets were found")
        for i, dataset in enumerate(dataset_api):
            print("[{}] ({}) {}".format(i, dataset["id"], dataset["name"]))


def clone(args, model_util, **kwargs):
    """Handle ``aoa clone``: clone a project's git repository.

    The project is looked up by ``--project-id``. If no id is given, or the lookup fails or finds
    nothing, the user picks one from the project list. The repository is cloned into ``--path``
    (default: the current base path) on the project's ``branch`` (default ``master``).
    """
    if args.cwd:
        set_cwd(args.cwd)

    project = None
    if args.project_id:
        from aoa import AoaClient
        from aoa import ProjectApi

        project_api = ProjectApi(aoa_client=AoaClient(), show_archived=False)
        try:
            project = project_api.find_by_id(args.project_id)
        except Exception:
            # A failed lookup is treated like an unknown id: fall back to interactive selection.
            project = None
        if not project:
            print('Project not found. Please select one from the list below or press Ctrl+C to quit.')

    if not project:
        project = list_and_select_projects(model_util, None, False, True)

    destination = args.path or base_path
    repo_url = project["gitRepositoryUrl"]
    if args.test:
        # Test mode: embed the demo git credentials into the URL.
        repo_url = repo_url.replace('http://demo-git', 'http://gituser:gitpassword@demo-git')
    model_util.clone_repository(repo_url, destination, project.get("branch", "master"))
    print("Project cloned at {}".format(destination))


def configure(args, model_util, **kwargs):
    """Handle ``aoa configure``.

    With ``--repo`` only the repository's project selection is (re)configured. Otherwise the API
    endpoint and authentication settings are prompted for and written to ``~/.aoa/config.yaml``.
    That file contains credentials, so it is created (or reset) with owner-only permissions.
    """
    from base64 import b64encode
    from pathlib import Path
    import yaml

    if args.cwd:
        set_cwd(args.cwd)

    def _select_auth_mode():
        """Prompt for the authentication mode, defaulting to basic."""
        print("")
        auth_modes = ["basic", "kerberos", "oauth"]
        return input_select('authentication mode', auth_modes, 'Supported authentication modes: ', 'basic')

    if args.repo:
        configure_repo(model_util)
        print("Project configured.")
        return

    # In test mode secrets are read with input() so they can be fed through stdin.
    called_from_test = bool(args.test)

    aoa_url = input_string('API endpoint', True)
    auth_mode = _select_auth_mode()

    config = {
        "aoa_url": aoa_url,
        "auth_mode": auth_mode
    }

    if auth_mode == "basic":
        print("")
        auth_user = input_string('username', True)
        auth_pass = input_string('password', True, '', True, called_from_test)

        config["auth_credentials"] = b64encode("{}:{}".format(auth_user, auth_pass).encode()).decode()

    # we don't support oauth_client_credentials via configure as its only meant for machine to machine auth which
    # wouldn't be using an interactive cli anyway
    elif auth_mode == "oauth":
        print("")
        config["auth_client_id"] = input_string('client_id', True)
        config["auth_client_secret"] = input_string('client_secret', True, '', True, called_from_test)
        config["auth_client_token_url"] = input_string('token_url', True)
        config["auth_client_refresh_token"] = input_string('refresh_token', True)

    config_dir = "{}/.aoa".format(Path.home())
    Path(config_dir).mkdir(parents=True, exist_ok=True)
    config_file = "{}/config.yaml".format(config_dir)

    # base64 is an encoding, not encryption: keep the credentials file readable by its owner only.
    fd = os.open(config_file, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with os.fdopen(fd, 'w') as f:
        yaml.dump(config, f, default_flow_style=False)
    # os.open only applies the mode to newly created files; tighten a pre-existing one too.
    os.chmod(config_file, 0o600)

    print("")
    print("New configuration saved at {}".format(config_file))


def message(args, model_util, **kwargs):
    """Handle ``aoa message``: send a job event or job progress message to the AOA message broker.

    The payload of ``--jobevent`` / ``--jobprogress`` is a JSON string. ``--jobevent`` wins when both
    are given. Exits with status 1 (after printing help) when neither is provided.
    """
    if args.cwd:
        set_cwd(args.cwd)

    from aoa import AoaClient
    from aoa import MessageApi
    import json

    if not args.jobevent and not args.jobprogress:
        logging.error("Invalid message type...")
        kwargs.get("parent_parser").print_help()
        sys.exit(1)

    client = AoaClient()
    message_api = MessageApi(client)

    # Parse the payload of the option that was actually given (previously always args.jobevent).
    if args.jobevent:
        message_api.send_job_event(json.loads(args.jobevent))
    else:
        message_api.send_progress_event(json.loads(args.jobprogress))


def set_cwd(path):
    """Switch the process working directory to ``path`` and re-derive the module-level repository paths.

    Rebinds ``base_path``, ``repo_template_catalog`` and ``model_catalog``. Exits with status 1
    when ``path`` does not exist.
    """
    global base_path, repo_template_catalog, model_catalog

    if os.path.exists(path):
        os.chdir(path)
    else:
        logging.error("The specified path does not exist... exiting")
        exit(1)

    base_path = os.path.abspath(os.getcwd())
    repo_template_catalog, model_catalog = (
        os.path.join(base_path, sub_dir) for sub_dir in ("model_templates/", "model_definitions/")
    )


def print_help(args, **kwargs):
    """Default action: print the package version for ``--version``, otherwise the top-level help."""
    from aoa import __version__

    if not args.version:
        kwargs.get("parent_parser").print_help()
        return
    print("{}".format(__version__))


def print_underscored(message):
    """Print ``message`` on one line and a row of dashes of equal length beneath it."""
    underline = "-" * len(message)
    print("{}\n{}".format(message, underline))


def list_and_select_projects(model_util, client=None, as_list=False, check_config=False):
    """Print the available projects and, unless ``as_list``, let the user select one.

    Args:
        model_util: used to look up the repository's current project (offered as the default choice).
        client: optional AoaClient. When given, it is switched to the selected project.
        as_list: only print the projects and return None.
        check_config: forwarded to ``get_current_project`` to warn about an unconfigured repository.

    Returns:
        The selected project record, or None when ``as_list`` is True.

    Exits with status 1 when a selection is required but no projects exist.
    """
    from aoa import ProjectApi

    while True:
        # The list is re-fetched on every attempt, as before.
        project_api = ProjectApi(aoa_client=client, show_archived=False)
        projects = list(project_api)

        print_underscored('List of projects:' if as_list else 'Available projects:')
        if not projects:
            print('No projects were found')
        for ix, item in enumerate(projects):
            print("[{}] ({}) {}".format(ix, item["id"], item["name"]))
        if as_list:
            return None

        if not projects:
            # Without this guard the prompt could never be answered validly and would repeat forever.
            logging.error("A project is required to continue. Please create one and try again.")
            sys.exit(1)

        current_project = get_current_project(model_util, check_config)
        current_index = "none"
        if current_project and 'id' in current_project:
            for ix, item in enumerate(projects):
                if current_project["id"] == item["id"]:
                    current_index = ix

        tmp_index = input("Select project by index (current selection: {}): ".format(current_index))
        print("")

        if tmp_index == '' and current_index != "none":
            selected_project = projects[current_index]
            break
        # isdecimal() only accepts digits int() can parse (isnumeric() also accepts e.g. '²').
        if tmp_index.isdecimal() and int(tmp_index) < len(projects):
            selected_project = projects[int(tmp_index)]
            break

        print('Wrong selection, please try again by selecting the index number on the first column.')
        print("")

    if client:
        client.set_project_id(selected_project["id"])

    return selected_project


def get_current_project(model_util, check_repo_conf=False):
    """Return the project configured for the model repository in the current directory.

    Args:
        model_util: provides ``read_repo_config()`` with the repository's ``project_id``.
        check_repo_conf: print a hint to run ``aoa configure --repo`` when no project is found.

    Returns:
        The project record from the API, or None when the directory is not a model repository,
        no project id is configured, or the id cannot be found.
    """
    from aoa import AoaClient
    from aoa import ProjectApi

    if not validate_cwd_valid():
        return None

    repo_conf = model_util.read_repo_config()
    current_project_id = repo_conf["project_id"] if repo_conf and "project_id" in repo_conf else ""

    current_project = None
    # Skip the API round-trip when no id is configured: an empty id cannot identify the repo's project.
    if current_project_id:
        client = AoaClient()
        project_api = ProjectApi(aoa_client=client, show_archived=False)
        current_project = project_api.find_by_id(current_project_id)

    if current_project:
        return current_project

    if check_repo_conf:
        print('The repository is not properly configured.')
        print('Please execute \'aoa configure --repo\'')
        print("")

    return None


def configure_repo(model_util):
    """Ask the user for a project and store its id as the repository configuration."""
    selected = list_and_select_projects(model_util, None, False, False)
    model_util.write_repo_config({"project_id": selected['id']})


def _build_parser():
    """Create the ``aoa`` argument parser with one sub-parser per CLI action."""
    parent_parser = argparse.ArgumentParser(description='AOA CLI')
    parent_parser.add_argument('--debug', action='store_true', help='Enable debug logging')
    parent_parser.add_argument('--version', action='store_true', help='Display the version of this tool')
    parent_parser.set_defaults(func=print_help)

    subparsers = parent_parser.add_subparsers(title="actions", description="valid actions", dest="command")

    parser_list = subparsers.add_parser("list", help="List projects, models, local models or datasets")
    parser_list.add_argument('--debug', action='store_true', help='Enable debug logging')
    parser_list.add_argument('-p', '--projects', action='store_true', help='List projects')
    parser_list.add_argument('-m', '--models', action='store_true',
                             help='List registered models (committed / pushed)')
    parser_list.add_argument('-lm', '--local-models', action='store_true',
                             help='List local models. Includes registered and non-registered (non-committed / non-pushed)')
    parser_list.add_argument('-d', '--datasets', action='store_true', help='List datasets')
    parser_list.add_argument('-cwd', '--cwd', type=str, help=argparse.SUPPRESS)
    parser_list.set_defaults(func=list_resources)

    parser_add = subparsers.add_parser("add", help="Add model to working dir")
    parser_add.add_argument('--debug', action='store_true', help='Enable debug logging')
    parser_add.add_argument('-cwd', '--cwd', type=str, help=argparse.SUPPRESS)
    parser_add.set_defaults(func=add_model)

    parser_run = subparsers.add_parser("run", help="Train and Evaluate model locally", )
    parser_run.add_argument('-id', '--model-id', type=str, help='Id of model')
    parser_run.add_argument('-m', '--mode', type=str, help='Mode (train, evaluate or "score (batch)")')
    parser_run.add_argument('-d', '--dataset-id', type=str, help='Remote datasetId')
    parser_run.add_argument('-ld', '--local-dataset', type=str, help='Path to local dataset metadata file')
    parser_run.add_argument('--debug', action='store_true', help='Enable debug logging')
    parser_run.add_argument('-cwd', '--cwd', type=str, help=argparse.SUPPRESS)
    parser_run.set_defaults(func=run_model)

    parser_init = subparsers.add_parser("init", help="Initialize model directory with basic structure")
    parser_init.add_argument('--debug', action='store_true', help='Enable debug logging')
    parser_init.add_argument('-cwd', '--cwd', type=str, help=argparse.SUPPRESS)
    parser_init.set_defaults(func=init_model_directory)

    parser_clone = subparsers.add_parser("clone", help="Clone Project Repository")
    parser_clone.add_argument('-id', '--project-id', type=str, help='Id of Project to clone')
    parser_clone.add_argument('-p', '--path', type=str, help='Path to clone repository to')
    parser_clone.add_argument('--debug', action='store_true', help='Enable debug logging')
    parser_clone.add_argument('--test', action='store_true', help=argparse.SUPPRESS)
    parser_clone.add_argument('-cwd', '--cwd', type=str, help=argparse.SUPPRESS)
    parser_clone.set_defaults(func=clone)

    parser_configure = subparsers.add_parser("configure", help="Configure AOA client")
    parser_configure.add_argument('--repo', action='store_true', help='Configure the repo only')
    parser_configure.add_argument('--debug', action='store_true', help='Enable debug logging')
    parser_configure.add_argument('--test', action='store_true', help=argparse.SUPPRESS)
    parser_configure.add_argument('-cwd', '--cwd', type=str, help=argparse.SUPPRESS)
    parser_configure.set_defaults(func=configure)

    parser_message = subparsers.add_parser("message", help="Send a message to AOA message broker")
    parser_message.add_argument('-j', '--jobevent', type=str, help='Send jobevent message')
    parser_message.add_argument('-p', '--jobprogress', type=str, help='Send jobprogress message')
    parser_message.add_argument('--debug', action='store_true', help='Enable debug logging')
    parser_message.add_argument('--test', action='store_true', help=argparse.SUPPRESS)
    parser_message.add_argument('-cwd', '--cwd', type=str, help=argparse.SUPPRESS)
    parser_message.set_defaults(func=message)

    return parent_parser


def main():
    """CLI entry point: parse arguments, set up logging and dispatch to the selected action.

    Connection errors and Ctrl+C are reported briefly and exit with status 1. Other errors exit with
    status 1 unless ``--debug`` was given, in which case they are re-raised with a stack trace.
    """
    # Pre-bind args so the generic handler below never hits an unbound name.
    args = None
    try:
        parent_parser = _build_parser()
        args = parent_parser.parse_args()

        logging_level = logging.DEBUG if args.debug else logging.INFO
        logging.basicConfig(format="%(message)s", stream=sys.stdout, level=logging_level)
        print("")

        from aoa import ModelUtility
        model_util = ModelUtility(base_path, repo_template_catalog)

        args.func(model_util=model_util, args=args, parent_parser=parent_parser)

    except requests.exceptions.ConnectionError as ce:
        logging.error("Please check that the services are running and the client")
        logging.error("is properly configured by executing 'aoa configure'.")
        logging.error("{}".format(ce))
        sys.exit(1)

    except KeyboardInterrupt:
        logging.info("")
        logging.info("Keyboard interrupt... exiting")
        sys.exit(1)

    except Exception as ex:
        if getattr(args, "debug", False):
            logging.info("An error occurred, printing stack trace output: ")
            raise
        logging.error("An error occurred: {}".format(ex))
        sys.exit(1)


# Run the CLI when executed as a script, or when this module is loaded under the name "aoa".
if __name__ in ("__main__", "aoa"):
    main()
