#!/usr/bin/env python
# PYTHON_ARGCOMPLETE_OK

import argparse
import argcomplete
import logging
import os
import requests
import sys
import re
import yaml

from pathlib import Path
from aoa.crypto import crypto

# Working-directory based locations; set_cwd() recomputes the first three
# whenever a command is pointed at another directory.
base_path = os.path.abspath(os.getcwd())
repo_template_catalog = os.path.join(base_path, "model_templates/")
model_catalog = os.path.join(base_path, "model_definitions/")

# Display names of the run modes offered by the interactive mode selector.
available_modes = ['Train', 'Evaluate', 'Score']

# Per-user directory holding the CLI configuration and connection files.
config_dir = f"{Path.home()}/.aoa"


def validate_cwd_valid():
    """Tell whether the ``model_catalog`` directory exists.

    Returns:
        bool: True when the path stored in ``model_catalog`` exists on disk.
    """
    catalog_present = os.path.exists(model_catalog)
    return catalog_present


def init_model_directory(args, repo_manager, **kwargs):
    """Initialise the model directory and make sure a repo config exists.

    Args:
        args: Parsed CLI arguments; ``args.cwd`` optionally selects the
            working directory.
        repo_manager: Object providing ``init_model_directory()`` and
            ``repo_config_exists()``.
        **kwargs: Ignored.
    """
    target_dir = args.cwd
    if target_dir:
        set_cwd(target_dir)

    repo_manager.init_model_directory()

    # Nothing more to do when the repository already carries a configuration.
    if repo_manager.repo_config_exists():
        return
    configure_repo(repo_manager)


def input_string(name, required=False, tooltip='', password=False, called_from_test=False):
    """Prompt the user for a string value.

    Args:
        name: Label shown in the ``Enter <name>: `` prompt.
        required: When True an empty answer is rejected and the prompt repeats.
        tooltip: Optional text printed before every prompt.
        password: When True the value is read with ``getpass`` (not echoed).
        called_from_test: When True ``input()`` is used even for passwords.

    Returns:
        str: The value typed by the user (possibly empty when not required).
    """
    from getpass import getpass

    prompt = "Enter {}: ".format(name)
    # Decided once so that retries keep hiding the password; the previous
    # recursive retry dropped these flags and echoed the re-typed secret.
    hide_input = password and not called_from_test

    while True:
        if tooltip != '':
            print(tooltip)

        value = getpass(prompt) if hide_input else input(prompt)
        if value != '' or not required:
            return value

        print('Value required. Please try again.')
        print('You may cancel at anytime by pressing Ctrl+C')
        print("")


def input_select(name, values, label='', default=None):
    """Let the user pick one entry of ``values`` by its index.

    Args:
        name: Noun used in the prompt (``Select <name> by index``).
        values: Sequence of selectable items.
        label: Optional heading printed (underscored) above the list.
        default: Item returned when the answer is left blank; only honoured
            when it is actually one of ``values``.

    Returns:
        The selected item, or None when ``values`` is empty.
    """
    if len(values) == 0:
        return None

    # Only advertise a default the user can really fall back to.
    has_default = bool(default) and default in values
    if has_default:
        prompt = "Select {} by index (or leave blank for the default one): ".format(name)
    else:
        prompt = "Select {} by index: ".format(name)

    while True:
        if label != '':
            print_underscored(label)

        for ix, item in enumerate(values):
            default_text = " (default)" if has_default and default == item else ''
            print("[{}] {}".format(ix, item) + default_text)

        answer = input(prompt).strip()

        if answer == '' and has_default:
            return values[values.index(default)]

        # isdecimal() matches exactly what int() can parse; isnumeric() also
        # accepts characters such as '½' or '²' on which int() raises.
        if answer.isdecimal() and int(answer) < len(values):
            return values[int(answer)]

        print('Wrong selection, please try again by selecting the index number on the first column.')
        print('You may cancel at anytime by pressing Ctrl+C.')
        print("")


def yes_or_no(question):
    """Ask a yes/no question until the answer starts with 'y' or 'n'.

    Args:
        question: Text shown before the `` (y/n): `` suffix.

    Returns:
        bool: True for an answer starting with 'y', False for 'n'
        (case-insensitive, surrounding whitespace ignored).
    """
    verdicts = {'y': True, 'n': False}
    while True:
        answer = str(input(question + ' (y/n): ')).lower().strip()
        initial = answer[:1]
        if initial in verdicts:
            return verdicts[initial]


def bash_escape(string):
    """Escape ``string`` for use inside a double-quoted bash word.

    Backslash, double quote, dollar sign and backtick keep their special
    meaning inside double quotes, so each is prefixed with a backslash.

    Args:
        string: Raw text to embed between double quotes.

    Returns:
        str: The escaped text.
    """
    # Single pass via translate, so backslashes added for other characters
    # are never escaped a second time.
    escapes = str.maketrans({
        '\\': '\\\\',
        '"': '\\"',
        '$': '\\$',
        '`': '\\`',
    })
    return string.translate(escapes)


def add_model(args, repo_manager, **kwargs):
    """Interactively create a new model from a template.

    Asks for the model name/description, a language and a template, then
    delegates the creation to ``repo_manager.add_model``.

    Args:
        args: Parsed CLI arguments; reads ``cwd``, ``template_url`` and
            ``branch``.
        repo_manager: Repository manager used to clone the optional template
            repository, list templates and create the model.
        **kwargs: Ignored.

    Exits with status 1 when the working directory is not a model repository
    or when no usable template can be selected.
    """
    if args.cwd:
        set_cwd(args.cwd)

    if not validate_cwd_valid():
        logging.error("Current working directory must the root of a model repository to execute this command")
        sys.exit(1)

    model_name = input_string('model name', True)
    model_desc = input_string('model description', False)

    temp_dir = None
    if args.template_url:
        import tempfile
        temp_dir = tempfile.TemporaryDirectory()

    try:
        if temp_dir is not None:
            logging.info("Using this template catalog from Git repo {}".format(args.template_url))
            repo_manager.clone_repository(args.template_url, temp_dir.name, args.branch)
            repo_manager.template_catalog = temp_dir.name + "/"

        catalog, custom_catalog = repo_manager.get_all_template_catalog()

        lang = sorted(set(catalog.keys()) | set(custom_catalog.keys()))
        if not lang:
            # input_select() would return None here and the lookups below
            # would fail with a KeyError.
            logging.error("Unable to find any model template")
            sys.exit(1)

        print("")
        model_lang = input_select("model language", lang, "Supported languages:")

        # Merge the templates offered for this language by both catalogs.
        template_names = set()
        for source in (catalog, custom_catalog):
            if model_lang in source:
                template_names.update(source[model_lang])
        models = sorted(template_names)

        print("")
        model_template = input_select("template type", models, "Supported templates:", 'empty')

        template_known = any(
            model_lang in source and model_template in source[model_lang]
            for source in (custom_catalog, catalog)
        )
        if not template_known:
            logging.error("Unable to find the specified model template")
            sys.exit(1)

        repo_manager.add_model(
            model_name=model_name,
            model_desc=model_desc,
            lang=model_lang,
            template=model_template,
            base_path=base_path
        )
        logging.info("Creating model structure for model: ({}) {}".format(
            repo_manager.model["id"], repo_manager.model["name"]))
    finally:
        # Remove the cloned template catalog explicitly instead of relying on
        # the TemporaryDirectory finalizer.
        if temp_dir is not None:
            temp_dir.cleanup()


def run_model(args, repo_manager, **kwargs):
    """Run a model locally in train, evaluate or score mode.

    Missing choices (model, mode, dataset or dataset template, connection)
    are requested interactively. The dataset definition is taken, in order of
    precedence, from ``args.local_dataset``, ``args.local_dataset_template``
    or the project's dataset (template) API.

    Args:
        args: Parsed CLI arguments; reads ``cwd``, ``model_id``, ``mode``,
            ``local_dataset``, ``local_dataset_template``, ``dataset_id``,
            ``dataset_template_id`` and ``connection``.
        repo_manager: Repository manager handed to the model runners.
        **kwargs: Ignored.

    Exits with status 1 when no project, model, dataset or dataset template
    can be resolved, or when the mode is unsupported.
    """
    from aoa import TrainModel
    from aoa import EvaluateModel
    from aoa import ScoreModel
    from aoa import AoaClient
    from aoa import DatasetApi
    from aoa import DatasetTemplateApi
    import json

    if args.cwd:
        set_cwd(args.cwd)

    current_project = get_current_project(repo_manager, True)
    if not current_project:
        logging.error("Current working directory must the root of a model repository to execute this command")
        sys.exit(1)

    client = AoaClient()
    client.set_project_id(current_project["id"])

    def _select_by_name(kind, entities, label):
        # Offer the entities' names and return the chosen entity (None when
        # there is nothing to choose from).
        selected_name = input_select(kind, [entity["name"] for entity in entities], label)
        selected = next((entity for entity in entities if entity["name"] == selected_name), None)
        print("")
        return selected

    def _select_model_id(catalog):
        selected_model = _select_by_name('model', [catalog[key] for key in catalog], 'Available models:')
        if selected_model is None:
            logging.error("No local models found")
            sys.exit(1)
        return selected_model['id'] if 'id' in selected_model else None

    def _select_mode():
        selected_mode = input_select('mode', available_modes, 'Available modes:').lower()
        print("")
        return selected_mode

    def _select_dataset(dataset_list):
        return _select_by_name('dataset', dataset_list, 'Available datasets:')

    def _select_dataset_template(template_list):
        return _select_by_name('dataset template', template_list, 'Available dataset templates:')

    def _select_connection(connection=None):
        from argparse import Namespace
        return activate_connection(Namespace(cwd=None, connection=connection))

    available_models = TrainModel.get_model_ids(model_catalog, True)

    if not args.model_id:
        model_id = _select_model_id(available_models)
    elif any(available_models[key]['id'] == args.model_id for key in available_models):
        model_id = args.model_id
    else:
        print('Model not found. Please select one from the list below or press Ctrl+C to quit.')
        model_id = _select_model_id(available_models)

    supported_modes = [known.lower() for known in available_modes]
    if not args.mode:
        mode = _select_mode()
    elif args.mode.lower() in supported_modes:
        # Accept any capitalisation of a supported mode.
        mode = args.mode.lower()
    else:
        print('Mode not found. Please select one from the list below or press Ctrl+C to quit.')
        mode = _select_mode()

    # Only set when the dataset comes from the API; the local-file branches
    # leave them as None.
    dataset_id = None
    dataset_template_id = None

    if args.local_dataset:
        with open(args.local_dataset, 'r') as f:
            rendered_dataset = json.load(f)

    elif args.local_dataset_template:
        with open(args.local_dataset_template, 'r') as f:
            rendered_dataset = json.load(f)

    elif mode == "score":
        dataset_template_api = DatasetTemplateApi(aoa_client=client, show_archived=False)
        available_templates = list(dataset_template_api)

        if args.dataset_template_id and any(
                item['id'] == args.dataset_template_id for item in available_templates):
            template = dataset_template_api.find_by_id(args.dataset_template_id)
        else:
            if args.dataset_template_id:
                print('Dataset template not found. Please select one from the list below or press Ctrl+C to quit.')
            template = _select_dataset_template(available_templates)

        if not template:
            logging.error("No dataset templates found in project")
            sys.exit(1)

        dataset_template_id = template["id"]
        rendered_dataset = dataset_template_api.render(id=dataset_template_id, scope="score")

    else:
        template_api = DatasetTemplateApi(aoa_client=client, show_archived=False)
        dataset_api = DatasetApi(aoa_client=client, show_archived=False)

        if args.dataset_id:
            dataset = dataset_api.find_by_id(args.dataset_id)
            if not dataset:
                logging.error("Dataset not found.")
                sys.exit(1)
        else:
            available_templates = list(template_api)
            template = _select_dataset_template(available_templates)
            if not template:
                logging.error("No dataset template found in project")
                sys.exit(1)

            from aoa.api.api_iterator import ApiIterator
            from functools import partial
            available_datasets = list(ApiIterator(
                api_func=partial(dataset_api.find_by_dataset_template_id, template["id"]),
                entities="datasets"))
            dataset = _select_dataset(available_datasets)
            if not dataset:
                logging.error("No datasets found in project")
                sys.exit(1)

        dataset_id = dataset["id"]
        rendered_dataset = dataset_api.render(dataset_id)

    connection_id = _select_connection(args.connection if args.connection else None)

    # Build the hint from the source that was actually used, so it never
    # references an id that was not resolved.
    if args.local_dataset:
        source_args = f"-c {connection_id} -ld {args.local_dataset}"
    elif args.local_dataset_template:
        source_args = f"-c {connection_id} -lt {args.local_dataset_template}"
    elif mode == "score":
        source_args = f"-t {dataset_template_id} -c {connection_id}"
    else:
        source_args = f"-d {dataset_id} -c {connection_id}"

    print("")
    print("To execute the same command again run:")
    print(f"aoa run -m {mode} -id {model_id} {source_args}")
    print("")
    print("")

    if mode == "train":
        trainer = TrainModel(repo_manager)
        trainer.train_model_local(model_id, rendered_dataset=rendered_dataset, base_path=base_path)
    elif mode == "evaluate":
        evaluator = EvaluateModel(repo_manager)
        evaluator.evaluate_model_local(model_id, rendered_dataset=rendered_dataset, base_path=base_path)
    elif mode == "score":
        scorer = ScoreModel(repo_manager)
        scorer.batch_score_model_local(model_id, rendered_dataset=rendered_dataset, base_path=base_path)
    else:
        logging.error("Unsupported mode used: " + mode)
        sys.exit(1)


def list_resources(args, repo_manager, **kwargs):
    """Print the resources selected by the ``args`` flags.

    Supported flags: ``projects``, ``models``, ``local_models``,
    ``templates``, ``datasets`` and ``connections``. When none is set an
    error is logged, the parent parser's help is printed and the process
    exits with status 1.

    Args:
        args: Parsed CLI arguments (flags above plus ``cwd``).
        repo_manager: Repository manager used to resolve the current project.
        **kwargs: ``parent_parser`` is used to print the help text.
    """
    from aoa import TrainModel
    from aoa.api.model_api import ModelApi
    from aoa import AoaClient
    from aoa import DatasetApi
    from aoa import DatasetTemplateApi

    if args.cwd:
        set_cwd(args.cwd)

    client = AoaClient()

    def _resolve_project():
        # Use the project bound to the repository when there is one,
        # otherwise let the user choose.
        found = get_current_project(repo_manager)
        if not found:
            return list_and_select_projects(repo_manager, client, False, True)
        if client:
            client.set_project_id(found["id"])
        return found

    if args.projects:
        list_and_select_projects(repo_manager, None, True, False)
        exit(0)

    if args.models:
        project = _resolve_project()
        model_api = ModelApi(client, show_archived=False)
        print_underscored("List of models for project {}:".format(project["name"]))
        if len(model_api) <= 0:
            print("No models were found")
        for position, model in enumerate(model_api):
            print("[{}] Id: ({}) Type: ({}) {}".format(position, model["id"], model["source"], model["name"]))

    if args.local_models:
        local_models = TrainModel.get_model_folders(model_catalog, True)
        print_underscored("List of local models:")
        if len(local_models) <= 0:
            print("No local models were found")
        for folder in local_models:
            details = local_models[folder]
            print("[{}] {} (Git: {})".format(folder, details["name"], details["id"]))

    if args.templates:
        project = _resolve_project()
        template_api = DatasetTemplateApi(aoa_client=client, show_archived=False)
        print_underscored("List of dataset templates for project {}:".format(project["name"]))
        if len(template_api) <= 0:
            print("No dataset templates were found")
        for position, template in enumerate(template_api):
            print("[{}] ({}) {}".format(position, template["id"], template["name"]))

    if args.datasets:
        project = _resolve_project()
        dataset_api = DatasetApi(aoa_client=client, show_archived=False)
        print_underscored("List of datasets for project {}:".format(project["name"]))
        if len(dataset_api) <= 0:
            print("No datasets were found")
        for position, dataset in enumerate(dataset_api):
            print("[{}] ({}) {}".format(position, dataset["id"], dataset["name"]))

    if args.connections:
        from argparse import Namespace
        list_connections(Namespace(cwd=None, connections=None), repo_manager)
        return

    requested = (args.projects, args.models, args.local_models,
                 args.templates, args.datasets, args.connections)
    if not any(requested):
        logging.error("Invalid object selection...")
        kwargs.get("parent_parser").print_help()
        exit(1)


def _repo_name_from_url(git_url):
    """Return the directory name implied by the last segment of a Git URL."""
    name = git_url.rstrip('/').split('/')[-1].split(':')[-1]
    # Strip only a trailing ".git"; splitting on the first ".git" breaks for
    # hosts such as "www.github.com".
    if name.endswith('.git'):
        name = name[:-len('.git')]
    return name


def clone(args, repo_manager, **kwargs):
    """Clone a project's Git repository and write its repo config.

    Args:
        args: Parsed CLI arguments; reads ``cwd``, ``project_id`` and ``path``.
        repo_manager: Repository manager used to clone and to write the
            repository configuration.
        **kwargs: Ignored.

    The project is looked up by ``args.project_id``; when it is missing or
    cannot be found the user selects one interactively. The clone goes to
    ``args.path`` or, by default, to ``<base_path>/<repository name>``.
    """
    if args.cwd:
        set_cwd(args.cwd)

    if not args.project_id:
        project = list_and_select_projects(repo_manager, None, False, True)
    else:
        from aoa import AoaClient
        from aoa import ProjectApi

        client = AoaClient()
        project_api = ProjectApi(aoa_client=client, show_archived=False)
        try:
            project = project_api.find_by_id(args.project_id)
        except Exception:
            # Any lookup failure falls back to the interactive selection.
            project = None
        if not project:
            print('Project not found. Please select one from the list below or press Ctrl+C to quit.')
            project = list_and_select_projects(repo_manager, None, False, True)

    project_git = project["gitRepositoryUrl"]

    if args.path:
        path = args.path
    else:
        path = os.path.join(base_path, _repo_name_from_url(project_git))

    repo_manager.clone_repository(project_git, path, project.get("branch", "master"))

    config = {
        "project_id": project['id']
    }
    repo_manager.write_repo_config(config, path)
    print("Project cloned at {}".format(path))


def configure(args, repo_manager, **kwargs):
    """Configure either the current repository or the CLI's API access.

    With ``args.repo`` set, only the repository is configured. Otherwise the
    API endpoint and credentials are requested interactively and written to
    ``<config_dir>/config.yaml``.

    Args:
        args: Parsed CLI arguments; reads ``cwd``, ``repo`` and ``test``.
        repo_manager: Repository manager passed to ``configure_repo``.
        **kwargs: Ignored.
    """
    from base64 import b64encode
    import yaml

    if args.cwd:
        set_cwd(args.cwd)

    if args.repo:
        configure_repo(repo_manager)
        print("Project configured.")
        return

    aoa_url = input_string('API endpoint', True)

    print("")
    auth_mode = input_select('authentication mode', ["basic", "oauth"],
                             'Supported authentication modes: ', 'basic')

    config = {"aoa_url": aoa_url, "auth_mode": auth_mode}

    if auth_mode == "basic":
        print("")
        auth_user = input_string('username', True)
        auth_pass = input_string('password', True, '', True, bool(args.test))

        raw_credentials = "{}:{}".format(auth_user, auth_pass)
        config["auth_credentials"] = b64encode(raw_credentials.encode()).decode()

    # we don't support oauth_client_credentials via configure as its only meant for machine to machine auth which
    # wouldn't be using an interactive cli anyway
    elif auth_mode == "oauth":
        print("")
        config["auth_client_id"] = input_string('client_id', True)
        config["auth_client_secret"] = input_string('client_secret', True, '', True, bool(args.test))
        config["auth_client_token_url"] = input_string('token_url', True)
        config["auth_client_refresh_token"] = input_string('refresh_token', True)

    Path(config_dir).mkdir(parents=True, exist_ok=True)
    config_file = f"{config_dir}/config.yaml"

    with open(config_file, 'w+') as handle:
        yaml.safe_dump(config, handle, default_flow_style=False)
        print("")
        print("New configuration saved at {}".format(config_file))


def message(args, repo_manager, **kwargs):
    """Send a job event or a job progress event through the message API.

    Args:
        args: Parsed CLI arguments; ``jobevent`` or ``jobprogress`` holds the
            JSON payload to send, ``cwd`` optionally selects the directory.
        repo_manager: Unused.
        **kwargs: ``parent_parser`` is used to print help on invalid usage.

    Exits with status 1 when neither payload is provided.
    """
    if args.cwd:
        set_cwd(args.cwd)

    from aoa import AoaClient
    from aoa import MessageApi
    import json

    # Choose the payload first; parsing args.jobevent unconditionally failed
    # whenever only a progress event was supplied.
    if args.jobevent:
        payload, is_job_event = args.jobevent, True
    elif args.jobprogress:
        payload, is_job_event = args.jobprogress, False
    else:
        logging.error("Invalid message type...")
        parent_parser = kwargs.get("parent_parser")
        if parent_parser is not None:
            parent_parser.print_help()
        sys.exit(1)

    client = AoaClient()
    message_api = MessageApi(client)

    event = json.loads(payload)

    if is_job_event:
        message_api.send_job_event(event)
    else:
        message_api.send_progress_event(event)


def list_connections(args, repo_manager, **kwargs):
    """Print the connections stored in ``<config_dir>/connections.yaml``.

    Args:
        args: Parsed CLI arguments; reads ``cwd``.
        repo_manager: Unused.
        **kwargs: Ignored.

    Exits with status 1 when the connections file cannot be read.
    """
    if args.cwd:
        set_cwd(args.cwd)

    connections_file = "{}/connections.yaml".format(config_dir)
    try:
        Path(config_dir).mkdir(parents=True, exist_ok=True)
        with open(connections_file, 'r') as f:
            connections = yaml.safe_load(f)
    except (OSError, yaml.YAMLError):
        logging.error("No connections file found. Please add them first.")
        sys.exit(1)

    # safe_load() yields None for an empty file, so guard before indexing.
    entries = connections.get('connections') if isinstance(connections, dict) else None

    if entries:
        print_underscored("List of local connections:")
        for i, connection in enumerate(entries):
            print("[{}] ({}) Name: {} Username: {} Host: {} Database: {}".format(
                i, connection["id"], connection["name"], connection["username"],
                connection["host"], connection.get("database", "")))
    else:
        logging.error("No connections found in file. Please add them first.")


def add_connections(args, repo_manager, **kwargs):
    """Add a connection to ``<config_dir>/connections.yaml``.

    The connection is taken from ``args`` when name, username, password and
    host are all given, or asked interactively when none of them is given.
    Any other combination is rejected.

    Args:
        args: Parsed CLI arguments; reads ``cwd``, ``name``, ``username``,
            ``password``, ``host`` and ``database``.
        repo_manager: Unused.
        **kwargs: ``parent_parser`` is used to print help on invalid usage.

    Exits with status 1 on invalid arguments or when the file cannot be
    written.
    """
    if args.cwd:
        set_cwd(args.cwd)

    connections_file = "{}/connections.yaml".format(config_dir)
    try:
        Path(config_dir).mkdir(parents=True, exist_ok=True)
        with open(connections_file, 'r') as f:
            connections_dict = yaml.safe_load(f)
    except (OSError, yaml.YAMLError):
        logging.info("No connections file found, creating new one...")
        connections_dict = None

    # An empty file loads as None; start from an empty document in that case.
    if not isinstance(connections_dict, dict):
        connections_dict = {}

    connections = connections_dict.get('connections') or []

    if args.name and args.username and args.password and args.host:
        name = args.name
        username = args.username
        password = args.password
        host = args.host
        # Store an empty string rather than None so the value can later be
        # exported as an environment variable.
        database = args.database or ""
    elif not (args.username or args.password or args.name or args.host):
        name = input_string('name', True, 'Type the connection name')
        username = input_string('username', True, 'Type the connection username')
        password = input_string('password', True, 'Type the connection password (will not show)', password=True)
        host = input_string('host', True, 'Type the connection host')
        database = input_string('database', False, 'Type the default database (optional)')
    else:
        logging.error("Invalid arguments...")
        # Other commands receive the parser through kwargs; accept both places.
        parent_parser = kwargs.get("parent_parser") or getattr(args, "parent_parser", None)
        if parent_parser is not None:
            parent_parser.print_help()
        sys.exit(1)

    import uuid
    connection_id = str(uuid.uuid4())
    encrypted_password = crypto.td_encrypt_password(
        password,
        "{}/{}.key".format(config_dir, connection_id),
        "{}/{}.pass".format(config_dir, connection_id))
    connections.append({
        'id': connection_id,
        'name': name,
        'username': username,
        'password': encrypted_password,
        'host': host,
        'logmech': 'TDNEGO',
        'database': database,
    })
    connections_dict['connections'] = connections

    try:
        with open(connections_file, 'w+') as f:
            yaml.safe_dump(connections_dict, f, default_flow_style=False)
    except Exception as ex:
        logging.error("Can't save connections file: {}".format(ex))
        sys.exit(1)


def remove_connections(args, repo_manager, **kwargs):
    """Remove a connection and its ``.key``/``.pass`` files.

    The connection is identified by ``args.connection`` (an id) or chosen
    interactively by name.

    Args:
        args: Parsed CLI arguments; reads ``cwd`` and ``connection``.
        repo_manager: Unused.
        **kwargs: Ignored.

    Exits with status 0 when there is nothing to remove and with status 1
    when the requested id is unknown or the file cannot be written.
    """
    if args.cwd:
        set_cwd(args.cwd)

    def _delete_secret_files(connection_id):
        # Remove the key/password files that belong to this connection id.
        for extension in ("key", "pass"):
            secret_file = "{}/{}.{}".format(config_dir, connection_id, extension)
            if os.path.exists(secret_file):
                os.remove(secret_file)

    connections_file = "{}/connections.yaml".format(config_dir)
    try:
        Path(config_dir).mkdir(parents=True, exist_ok=True)
        with open(connections_file, 'r') as f:
            connections_dict = yaml.safe_load(f)
    except (OSError, yaml.YAMLError):
        logging.info("No connections file found, nothing to remove.")
        sys.exit(0)

    connections = connections_dict.get('connections') if isinstance(connections_dict, dict) else None
    if not connections:
        # Covers an empty file (None) as well as an empty connection list.
        logging.info("No connections defined, nothing to remove.")
        sys.exit(0)

    if args.connection:
        connection_id = args.connection
        known = any('id' in item and item['id'] == connection_id for item in connections)
        if not known:
            logging.info('Connection does not exists, exiting.')
            sys.exit(1)
    else:
        name = input_select('connection', [item['name'] for item in connections], 'List of local connections')
        connection_id = next((item['id'] for item in connections if item['name'] == name), None)

    _delete_secret_files(connection_id)
    connections_dict['connections'] = [item for item in connections if item.get('id') != connection_id]

    try:
        with open(connections_file, 'w+') as f:
            yaml.safe_dump(connections_dict, f, default_flow_style=False)
    except Exception as ex:
        logging.error("Can't save connections file: {}".format(ex))
        sys.exit(1)


def export_connection(args, repo_manager, **kwargs):
    """Print shell ``export`` commands for a stored connection.

    The connection is identified by ``args.connection`` (an id) or chosen
    interactively by name.

    Args:
        args: Parsed CLI arguments; reads ``cwd`` and ``connection``.
        repo_manager: Unused.
        **kwargs: Ignored.

    Exits with status 0 when no connections are available and with status 1
    when the requested connection cannot be found.
    """
    if args.cwd:
        set_cwd(args.cwd)

    connections_file = "{}/connections.yaml".format(config_dir)
    try:
        Path(config_dir).mkdir(parents=True, exist_ok=True)
        with open(connections_file, 'r') as f:
            connections_dict = yaml.safe_load(f)
    except (OSError, yaml.YAMLError):
        logging.info("No connections file found, Please add them first.")
        sys.exit(0)

    connections = connections_dict.get('connections') if isinstance(connections_dict, dict) else None
    if not connections:
        # Covers an empty file (None) as well as an empty connection list.
        logging.info("No connections defined, Please add them first.")
        sys.exit(0)

    if args.connection:
        connection = next(
            (item for item in connections if 'id' in item and item['id'] == args.connection), None)
    else:
        name = input_select('connection', [item['name'] for item in connections], 'List of local connections')
        connection = next((item for item in connections if item['name'] == name), None)

    if connection is None:
        logging.info('Connection does not exists, exiting.')
        sys.exit(1)

    variables = [
        ("AOA_CONN_USERNAME", connection['username']),
        ("AOA_CONN_PASSWORD", connection['password']),
        ("AOA_CONN_HOST", connection['host']),
        ("AOA_CONN_LOG_MECH", connection['logmech']),
        # A stored null database is exported as an empty string.
        ("AOA_CONN_DATABASE", connection.get('database') or ""),
    ]

    print("\nCopy the below command and execute in your terminal: \n")
    # Join with "&& \" so the last export does not end in a dangling
    # continuation that leaves the shell waiting for more input.
    print(" && \\\n".join(
        "export {}=\"{}\"".format(variable, bash_escape(value)) for variable, value in variables))


def activate_connection(args, **kwargs):
    """Export a stored connection into the ``AOA_CONN_*`` environment variables.

    The connection id comes from ``args.connection``, then from the
    ``connection`` keyword argument. Without either, the only stored
    connection is used, or the user is asked to choose when there are several.

    Args:
        args: Namespace with ``cwd`` and ``connection`` attributes.
        **kwargs: Optional ``connection`` id.

    Returns:
        The id of the activated connection.

    Exits with status 1 when the connections file cannot be read or the
    connection is not found.
    """
    if args.cwd:
        set_cwd(args.cwd)

    connections_file = "{}/connections.yaml".format(config_dir)
    try:
        Path(config_dir).mkdir(parents=True, exist_ok=True)
        with open(connections_file, 'r') as f:
            connections_dict = yaml.safe_load(f)
    except (OSError, yaml.YAMLError):
        logging.info("No connections file found, you must create a connection first")
        sys.exit(1)

    # An empty file loads as None; treat it like a file without connections.
    connections = (connections_dict.get('connections') if isinstance(connections_dict, dict) else None) or []

    if args.connection:
        connection = args.connection
    elif kwargs.get('connection'):
        connection = kwargs.get('connection')
    elif len(connections) == 1:
        print("Automatic connection selection as only one available: {}".format(connections[0]["name"]))
        connection = connections[0]["id"]
    else:
        selected_connection_value = input_select('connection', [item['name'] for item in connections],
                                                 'Available connections:')
        connection = next(
            (item['id'] for item in connections if item['name'] == selected_connection_value), None)

    selected = next((item for item in connections if 'id' in item and item['id'] == connection), None)
    if selected is None:
        logging.error('The specified connection was not found.')
        sys.exit(1)

    os.environ['AOA_CONN_USERNAME'] = selected['username']
    os.environ['AOA_CONN_PASSWORD'] = selected['password']
    os.environ['AOA_CONN_HOST'] = selected['host']
    os.environ['AOA_CONN_LOG_MECH'] = selected['logmech']
    # os.environ only accepts strings; a database stored as null would raise.
    os.environ['AOA_CONN_DATABASE'] = selected.get('database') or ""

    return connection


def test_connection(args, repo_manager, **kwargs):
    """Activate a connection and query the database version through it.

    Args:
        args: Parsed CLI arguments; reads ``cwd`` and ``connection``.
        repo_manager: Unused.
        **kwargs: Ignored.

    Logs the version on success; on any failure logs the error and exits
    with status 1.
    """
    if args.cwd:
        set_cwd(args.cwd)

    from argparse import Namespace
    requested = args.connection if args.connection else None
    activate_connection(Namespace(cwd=None, connection=requested))

    try:
        # Imports stay inside the try block so a missing dependency is
        # reported like any other connection failure.
        from aoa.util import aoa_create_context
        from teradataml import get_connection

        aoa_create_context()
        rows = get_connection().execute("SEL infodata FROM dbc.dbcinfo WHERE infokey = 'VERSION'").fetchall()
        ver = rows[0][0]
        logging.info("Connection successful, Vantage version: {}".format(ver))
    except Exception as ex:
        logging.error("Failed to test connecton: {}".format(ex))
        logging.error("Please use 'aoa connection add' to add working connection")
        exit(1)


def compute_stats(args, repo_manager, **kwargs):
    """Compute feature statistics for a table and save them.

    Args:
        args: Parsed CLI arguments; reads ``cwd``, ``feature_type``
            ('categorical' or 'continuous'), ``source_table``, ``columns``
            (comma separated) and ``metadata_table``.
        repo_manager: Unused.
        **kwargs: Ignored.

    Any failure is logged with its traceback and the process exits with
    status 1.
    """
    if args.cwd:
        set_cwd(args.cwd)

    from argparse import Namespace
    activate_connection(Namespace(cwd=None, connection=None))

    from aoa.util import aoa_create_context
    from aoa.stats import stats, store
    from teradataml import DataFrame
    try:
        aoa_create_context()
        kind = args.feature_type
        if kind == 'categorical':
            frame = DataFrame.from_query(f"SEL * FROM {args.source_table}")
            feature_stats = stats.compute_categorical_stats(frame, args.columns.split(","))
            store.save_feature_stats(features_table=args.metadata_table, stats=feature_stats)
        elif kind == 'continuous':
            frame = DataFrame.from_query(f"SEL * FROM  {args.source_table}")
            feature_stats = stats.compute_continuous_stats(frame, args.columns.split(","))
            store.save_feature_stats(features_table=args.metadata_table, stats=feature_stats)
    except Exception:
        logging.exception("Could not compute feature stats")
        exit(1)


def list_stats(args, repo_manager, **kwargs):
    """Print the feature statistics stored in ``args.metadata_table``.

    Args:
        args: Parsed CLI arguments; reads ``cwd`` and ``metadata_table``.
        repo_manager: Unused.
        **kwargs: Ignored.

    Any failure is logged with its traceback and the process exits with
    status 1.
    """
    if args.cwd:
        set_cwd(args.cwd)

    from argparse import Namespace
    activate_connection(Namespace(cwd=None, connection=None))
    from aoa.util import aoa_create_context
    from aoa.stats import store
    try:
        aoa_create_context()
        feature_stats = store.get_feature_stats(args.metadata_table)
        print(feature_stats)
    except Exception as error:
        logging.exception("Could not compute feature stats: {}".format(error))
        exit(1)


def doctor(args, repo_manager, **kwargs):
    """Check the ModelOps service configuration and the Vantage connection.

    Lists the projects through the API to verify the service configuration,
    then runs ``test_connection`` with no preselected connection.

    Args:
        args: Parsed CLI arguments; ``args.connection`` is reset to None
            before the connection test.
        repo_manager: Passed through to ``test_connection``.
        **kwargs: Ignored.
    """
    logging.info("Testing ModelOps service configuration")

    from aoa import AoaClient
    from aoa import ProjectApi

    failure_message = ("Failed to connect to ModelOps or find any project, "
                       "use 'aoa configure' to update ModelOps configuraion")
    try:
        client = AoaClient()
        project_api = ProjectApi(aoa_client=client, show_archived=False)
        projects = list(project_api)
    except Exception:
        logging.exception(failure_message)
    else:
        if projects:
            logging.info("ModelOps service configured, number of projects: {}".format(len(projects)))
        else:
            # Report the empty result directly; a bare `raise` here has no
            # active exception and only produced an unrelated RuntimeError.
            logging.error(failure_message)

    print("")

    logging.info("Testing Vantage connections")
    args.connection = None
    test_connection(args, repo_manager)


def set_cwd(path):
    """Change the working directory and refresh the derived module paths.

    Args:
        path: Directory to switch to.

    Updates the module-level ``base_path``, ``repo_template_catalog`` and
    ``model_catalog``. Logs an error and exits with status 1 when ``path``
    does not exist.
    """
    global base_path, repo_template_catalog, model_catalog

    if not os.path.exists(path):
        logging.error("The specified path does not exist... exiting")
        exit(1)

    os.chdir(path)
    base_path = os.path.abspath(os.getcwd())
    repo_template_catalog, model_catalog = (
        os.path.join(base_path, folder) for folder in ("model_templates/", "model_definitions/")
    )


def print_help(args, **kwargs):
    """Print the package version or the parent parser's help.

    Args:
        args: Parsed CLI arguments; ``args.version`` selects the version
            output.
        **kwargs: ``parent_parser`` provides ``print_help()``.
    """
    from aoa import __version__

    if not args.version:
        kwargs.get("parent_parser").print_help()
        return

    print("{}".format(__version__))


def print_underscored(message):
    """Print *message* followed by a line of dashes of the same length."""
    underline = '-' * len(message)
    print(message, underline, sep='\n')


def list_and_select_projects(repo_manager, client=None, as_list=False, check_config=False):
    """List the non-archived projects and optionally let the user pick one.

    Args:
        repo_manager: forwarded to ``get_current_project`` to look up the
            project currently configured for the repository.
        client: client handed to ``ProjectApi``; when truthy, the selected
            project id is also set on it via ``set_project_id``.
        as_list: when true, only print the projects and return ``None``.
        check_config: forwarded to ``get_current_project``.

    Returns:
        The selected project entry, or ``None`` when ``as_list`` is true.

    The user is re-prompted (projects are fetched and printed again) until a
    valid index is entered. An empty answer keeps the current project when
    one is configured and present in the list.
    """
    from aoa import ProjectApi

    # Loop instead of recursing so repeated wrong answers cannot exhaust the
    # recursion limit; every retry still re-fetches and re-prints the list.
    while True:
        project_api = ProjectApi(aoa_client=client, show_archived=False)
        projects = list(project_api)

        print_underscored('List of projects:' if as_list else 'Available projects:')
        if not projects:
            # NOTE: in selection mode the prompt below is still shown and
            # every answer is rejected, because no index can be valid.
            print('No projects were found')
        for ix, item in enumerate(projects):
            print("[{}] ({}) {}".format(ix, item["id"], item["name"]))
        if as_list:
            return

        current_project = get_current_project(repo_manager, check_config)
        current_index = None
        if current_project and 'id' in current_project:
            for ix, item in enumerate(projects):
                if current_project["id"] == item["id"]:
                    current_index = ix

        shown_index = "none" if current_index is None else current_index
        tmp_index = input("Select project by index (current selection: {}): ".format(shown_index))
        print("")

        if tmp_index == '':
            # A blank answer is only acceptable if there is a current project.
            if current_index is not None:
                selected_index = current_index
                break
        # isdecimal() (not isnumeric()) guarantees int() will accept the text;
        # isnumeric() is also true for characters such as superscripts.
        elif tmp_index.isdecimal() and int(tmp_index) < len(projects):
            selected_index = int(tmp_index)
            break

        print('Wrong selection, please try again by selecting the index number on the first column.')
        print("")

    selected_project = projects[selected_index]

    if client:
        client.set_project_id(selected_project["id"])

    return selected_project


def get_current_project(repo_manager, check_repo_conf=False):
    """Return the project referenced by the repository configuration.

    Returns ``None`` when the working directory is not valid (see
    ``validate_cwd_valid``) or when no project is found for the configured
    ``project_id``. In the latter case, if ``check_repo_conf`` is true, a
    hint to run ``aoa configure --repo`` is printed.
    """
    from aoa import AoaClient
    from aoa import ProjectApi

    if not validate_cwd_valid():
        return None

    repo_conf = repo_manager.read_repo_config()
    if repo_conf and "project_id" in repo_conf:
        project_id = repo_conf["project_id"]
    else:
        # No id configured: the lookup is still made, with an empty id.
        project_id = ""

    client = AoaClient()
    project_api = ProjectApi(aoa_client=client, show_archived=False)
    project = project_api.find_by_id(project_id)
    if project:
        return project

    if check_repo_conf:
        print('The repository is not properly configured.')
        print('Please execute \'aoa configure --repo\'')
        print("")

    return None


def configure_repo(repo_manager):
    """Ask the user to select a project and store its id in the repo config."""
    selected = list_and_select_projects(
        repo_manager,
        client=None,
        as_list=False,
        check_config=False,
    )
    repo_manager.write_repo_config({"project_id": selected['id']})


def _build_parser():
    """Build the top-level ``aoa`` argument parser with all of its sub-commands."""
    parent_parser = argparse.ArgumentParser(description='AOA CLI')
    parent_parser.add_argument('--debug', action='store_true', help='Enable debug logging')
    parent_parser.add_argument('--version', action='store_true', help='Display the version of this tool')
    parent_parser.set_defaults(func=print_help)

    subparsers = parent_parser.add_subparsers(title="actions", description="valid actions", dest="command")

    common_parser = argparse.ArgumentParser(add_help=False)
    # SUPPRESS keeps the sub-command from writing debug=False into the
    # namespace when the flag is absent; otherwise that default would
    # overwrite a --debug given before the sub-command name.
    common_parser.add_argument('--debug', action='store_true', default=argparse.SUPPRESS, help='Enable debug logging')
    common_parser.add_argument('-cwd', '--cwd', type=str, help=argparse.SUPPRESS)
    common_parser.add_argument('--test', action='store_true', help=argparse.SUPPRESS)

    parser_list = subparsers.add_parser("list", help="List projects, models, local models or datasets", parents=[common_parser])
    parser_list.add_argument('-p', '--projects', action='store_true', help='List projects')
    parser_list.add_argument('-m', '--models', action='store_true', help='List registered models (committed / pushed)')
    parser_list.add_argument('-lm', '--local-models', action='store_true', help='List local models. Includes registered and non-registered (non-committed / non-pushed)')
    parser_list.add_argument('-t', '--templates', action='store_true', help='List dataset templates')
    parser_list.add_argument('-d', '--datasets', action='store_true', help='List datasets')
    parser_list.add_argument('-c', '--connections', action='store_true', help='List local connections')
    parser_list.set_defaults(func=list_resources)

    parser_add = subparsers.add_parser("add", help="Add model to working dir", parents=[common_parser])
    parser_add.add_argument('-t', '--template-url', type=str, help='Git URL for template repository')
    parser_add.add_argument('-b', '--branch', type=str, default='main', help='Git branch to pull templates')
    parser_add.set_defaults(func=add_model)

    parser_run = subparsers.add_parser("run", help="Train and Evaluate model locally", parents=[common_parser])
    parser_run.add_argument('-id', '--model-id', type=str, help='Id of model')
    parser_run.add_argument('-m', '--mode', type=str, help='Mode (train or evaluate)')
    parser_run.add_argument('-d', '--dataset-id', type=str, help='Remote datasetId')
    parser_run.add_argument('-t', '--dataset-template-id', type=str, help='Remote datasetTemplateId')
    parser_run.add_argument('-ld', '--local-dataset', type=str, help='Path to local dataset metadata file')
    parser_run.add_argument('-lt', '--local-dataset-template', type=str, help='Path to local dataset template metadata file')
    parser_run.add_argument('-c', '--connection', type=str, help='Local connection id')
    parser_run.set_defaults(func=run_model)

    parser_init = subparsers.add_parser("init", help="Initialize model directory with basic structure", parents=[common_parser])
    parser_init.set_defaults(func=init_model_directory)

    parser_clone = subparsers.add_parser("clone", help="Clone Project Repository", parents=[common_parser])
    parser_clone.add_argument('-id', '--project-id', type=str, help='Id of Project to clone')
    parser_clone.add_argument('-p', '--path', type=str, help='Path to clone repository to')
    parser_clone.set_defaults(func=clone)

    parser_configure = subparsers.add_parser("configure", help="Configure AOA client", parents=[common_parser])
    parser_configure.add_argument('--repo', action='store_true', help='Configure the repo only')
    parser_configure.set_defaults(func=configure)

    parser_message = subparsers.add_parser("message", help="Send a message to AOA message broker", parents=[common_parser])
    parser_message.add_argument('-j', '--jobevent', type=str, help='Send jobevent message')
    parser_message.add_argument('-p', '--jobprogress', type=str, help='Send jobprogress message')
    parser_message.set_defaults(func=message)

    parser_connections = subparsers.add_parser("connection", help="Manage local connections")
    subparser_connections = parser_connections.add_subparsers(title="actions", description="valid actions", dest="command")
    subparser_list_connections = subparser_connections.add_parser("list", help="List all local connections", parents=[common_parser])
    subparser_list_connections.set_defaults(func=list_connections)
    subparser_add_connections = subparser_connections.add_parser("add", help="Add a local connection", parents=[common_parser])
    subparser_add_connections.add_argument('-n', '--name', type=str, help='Connection name')
    subparser_add_connections.add_argument('-H', '--host', type=str, help='Connection host')
    subparser_add_connections.add_argument('-u', '--username', type=str, help='Connection username')
    subparser_add_connections.add_argument('-p', '--password', type=str, help='Connection password')
    subparser_add_connections.add_argument('-d', '--database', type=str, help='Connection default database')
    subparser_add_connections.set_defaults(func=add_connections, parent_parser=subparser_add_connections)
    subparser_remove_connections = subparser_connections.add_parser("remove", help="Remove a local connection", parents=[common_parser])
    subparser_remove_connections.add_argument('-c', '--connection', type=str, help='Local connection id')
    subparser_remove_connections.set_defaults(func=remove_connections)
    subparser_export_connections = subparser_connections.add_parser("export", help="Export a local connection to be used as a shell script", parents=[common_parser])
    subparser_export_connections.add_argument('-c', '--connection', type=str, help='Local connection id')
    subparser_export_connections.set_defaults(func=export_connection)
    subparser_test_connections = subparser_connections.add_parser("test", help="Test a local connection", parents=[common_parser])
    subparser_test_connections.add_argument('-c', '--connection', type=str, help='Local connection id')
    subparser_test_connections.set_defaults(func=test_connection)

    parser_features = subparsers.add_parser("features", help="Manage feature statistics")
    subparser_features = parser_features.add_subparsers(title="action", description="valid actions", dest="command")
    subparser_compute_stats = subparser_features.add_parser("compute-stats", help="Compute feature statistics", parents=[common_parser])
    subparser_compute_stats.add_argument('-s', '--source-table', type=str, help='Feature source table/view')
    subparser_compute_stats.add_argument('-m', '--metadata-table', type=str, help='Metadata table for feature stats, including database name')
    subparser_compute_stats.add_argument('-t', '--feature-type', choices=['continuous', 'categorical'], default='continuous', help='Feature type: continuous or categorical')
    subparser_compute_stats.add_argument('-c', '--columns', type=str, help='List of feature columns')
    subparser_compute_stats.set_defaults(func=compute_stats, parent_parser=subparser_compute_stats)
    subparser_list_stats = subparser_features.add_parser("list-stats", help="List available statistics", parents=[common_parser])
    subparser_list_stats.add_argument('-m', '--metadata-table', type=str, help='Metadata table for feature stats, including database name')
    subparser_list_stats.set_defaults(func=list_stats, parent_parser=subparser_list_stats)

    parser_doctor = subparsers.add_parser("doctor", help="Diagnose configuration issues", parents=[common_parser])
    parser_doctor.set_defaults(func=doctor)

    return parent_parser


def main():
    """CLI entry point: parse the command line and dispatch to the handler.

    Configures logging to stdout (DEBUG when ``--debug`` is given, otherwise
    INFO), creates a ``RepoManager`` for the current base path and calls the
    handler selected by the sub-command with ``repo_manager``, ``args`` and
    ``parent_parser`` keyword arguments.

    Error handling:
        * ``requests`` connection errors: log a configuration hint, exit 1.
        * ``KeyboardInterrupt``: log a short notice, exit 1.
        * any other exception: re-raised when debugging is enabled, otherwise
          logged as a one-line error followed by exit 1.
    """
    # Bound before the try block so the generic handler below can run even
    # when the failure happens before the command line has been parsed.
    args = None
    try:
        parent_parser = _build_parser()

        argcomplete.autocomplete(parent_parser)

        args = parent_parser.parse_args()

        logging_level = logging.DEBUG if args.debug else logging.INFO
        logging.basicConfig(format="%(message)s", stream=sys.stdout, level=logging_level)
        print("")

        from aoa import RepoManager
        repo_manager = RepoManager(base_path, repo_template_catalog)

        args.func(repo_manager=repo_manager, args=args, parent_parser=parent_parser)

    except requests.exceptions.ConnectionError as ce:
        logging.error("Please check that the api service is running and the client")
        logging.error("is properly configured by executing 'aoa configure'.\n")
        logging.error("{}".format(ce))
        exit(1)

    except KeyboardInterrupt:
        logging.info("")
        logging.info("Keyboard interrupt... exiting")
        exit(1)

    except Exception as ex:
        # Fall back to the raw argv when parsing never completed.
        debug = args.debug if args is not None else '--debug' in sys.argv[1:]
        if debug:
            logging.info("An error occurred, printing stack trace output: ")
            raise
        else:
            logging.error("An error occurred: {}".format(ex))
            exit(1)


# Run the CLI when this module is loaded under the name "__main__" or "aoa".
if __name__ in ("__main__", "aoa"):
    main()
