#!/usr/bin/env python3

import argparse
import os
import platform
import random
import shutil
import string
import subprocess
import sys
import time
import yaml

import distro

# Root of the local Aqueduct installation; the server and UI trees live beneath it.
base_directory = os.path.join(os.environ["HOME"], ".aqueduct")
server_directory = os.path.join(base_directory, "server")
ui_directory = os.path.join(base_directory, "ui")

# Version of this package. It selects the remote asset directory and is the
# value written to the __version__ marker files.
package_version = "0.0.2"

# Banner template printed after start-up; filled with (API key, host).
welcome_message = """
***************************************************
Your API Key: %s

The Web UI and the backend server are accessible at: http://%s:8080
***************************************************
"""

def update_config_yaml(file):
    """Fill in the placeholders of a config template, rewriting it in place.

    Every ``<ENCRYPTION_KEY>`` and ``<API_KEY>`` occurrence is replaced with a
    freshly generated 32-character key (uppercase letters and digits), and
    every ``<BASE_PATH>`` occurrence with ``server_directory``.  All
    occurrences are replaced, including several placeholders on one line.

    Args:
        file: Path of the config file to rewrite.
    """
    # Local import: keys are secrets, so draw them from the OS CSPRNG rather
    # than the `random` module, which is not meant for security purposes.
    import secrets

    alphabet = string.ascii_uppercase + string.digits
    encryption_key = "".join(secrets.choice(alphabet) for _ in range(32))
    api_key = "".join(secrets.choice(alphabet) for _ in range(32))

    with open(file, "r") as sources:
        content = sources.read()

    content = content.replace("<ENCRYPTION_KEY>", encryption_key)
    content = content.replace("<API_KEY>", api_key)
    if "<BASE_PATH>" in content:
        content = content.replace("<BASE_PATH>", server_directory)

    with open(file, "w") as sources:
        sources.write(content)

def execute_command(args, cwd=None):
    """Run a command to completion, sharing this process's stdout and stderr.

    Args:
        args: The command and its arguments, as a sequence of strings.
        cwd: Optional working directory for the child process.

    Raises:
        Exception: If the command exits with a non-zero return code.
    """
    with subprocess.Popen(args, stdout=sys.stdout, stderr=sys.stderr, cwd=cwd) as proc:
        proc.communicate()
        if proc.returncode != 0:
            # Wrap args in a 1-tuple: a bare tuple on the right of % would be
            # unpacked as several format arguments and raise TypeError.
            raise Exception("Error executing command: %s" % (args,))

def execute_command_nonblocking(args, cwd=None):
    """Start a command without waiting for it to finish.

    The child shares this process's stdout and stderr.

    Args:
        args: The command and its arguments.
        cwd: Optional working directory for the child process.

    Returns:
        The ``subprocess.Popen`` handle of the started process.
    """
    child = subprocess.Popen(args, cwd=cwd, stderr=sys.stderr, stdout=sys.stdout)
    return child

def update_executable_permissions():
    """Set mode 755 on the server, executor and migrator files in the server's bin directory."""
    bin_directory = os.path.join(server_directory, "bin")
    for binary_name in ("server", "executor", "migrator"):
        execute_command(["chmod", "755", os.path.join(bin_directory, binary_name)])

def generate_version_file(file_path):
    """Write ``package_version`` to ``file_path``, replacing any existing content.

    Args:
        file_path: Path of the version marker file to (over)write.
    """
    with open(file_path, mode='w') as version_file:
        version_file.write(package_version)

def _parse_version(version):
    """Return a dotted version string as a tuple of ints, or None if any part is not an integer."""
    try:
        return tuple(int(part) for part in version.split("."))
    except ValueError:
        return None


# Returns a bool indicating whether we need to perform a version upgrade.
def require_update(file_path):
    """Decide whether the installation recorded in ``file_path`` needs an update.

    Versions are compared numerically, component by component, so "0.0.10" is
    newer than "0.0.2".  Surrounding whitespace in the file (such as a
    trailing newline) is ignored.  If either version is not purely numeric,
    the comparison falls back to plain string ordering.

    Args:
        file_path: Path of the version marker file.

    Returns:
        True if the file is missing or records an older version than
        ``package_version``; False if it records the same version.

    Raises:
        Exception: If the file records a version newer than ``package_version``.
    """
    if not os.path.isfile(file_path):
        return True
    with open(file_path, 'r') as f:
        current_version = f.read().strip()

    wanted = _parse_version(package_version)
    installed = _parse_version(current_version)
    if wanted is None or installed is None:
        # At least one side is not numeric: compare the raw strings instead.
        wanted, installed = package_version, current_version

    if wanted < installed:
        raise Exception("Attempting to install an older version %s but found existing newer version %s" % (package_version, current_version))
    return wanted != installed

def update_ui_version():
    """Replace the local UI directory with the bundle for ``package_version``.

    Removes ``ui_directory``, recreates it, downloads ``ui.zip`` for the
    current version, unpacks it there, deletes the archive and finally writes
    the ``__version__`` marker.  If any step raises, the error is printed,
    ``ui_directory`` is removed and the process exits with status 1.
    """
    print("Updating UI version to %s" % package_version)
    archive_path = os.path.join(ui_directory, "ui.zip")
    try:
        execute_command(["rm", "-rf", ui_directory])
        os.mkdir(ui_directory)
        s3_prefix = "https://aqueduct-ai.s3.us-east-2.amazonaws.com/assets/%s/ui" % package_version
        # --fail makes curl exit non-zero on an HTTP error instead of saving
        # the server's error response as ui.zip.
        execute_command(["curl", "--fail", os.path.join(s3_prefix, "ui.zip"), "--output", archive_path])
        execute_command(["unzip", archive_path, "-d", ui_directory])
        execute_command(["rm", archive_path])
        # Write the marker last, so an install interrupted part-way (e.g. by
        # Ctrl-C, which the handler below does not catch) is not later
        # mistaken for a complete one by require_update().
        generate_version_file(os.path.join(ui_directory, "__version__"))
    except Exception as e:
        print(e)
        execute_command(["rm", "-rf", ui_directory])
        # sys.exit rather than the site-provided exit(), which is intended for
        # interactive sessions.
        sys.exit(1)

def download_golang_binaries(s3_prefix):
    """Download the server, executor and migrator binaries for this platform.

    The files are fetched with curl from ``<s3_prefix>/bin/<os>_<arch>/`` into
    the ``bin`` directory under ``server_directory``.

    Args:
        s3_prefix: Base URL of the server assets.

    Raises:
        Exception: If the OS/architecture pair is not Linux x86_64, macOS
            x86_64 or macOS arm64, or if a download command fails.
    """
    system = platform.system()
    arch = platform.machine()
    # (platform.system(), platform.machine()) -> (OS label, arch label, remote directory)
    supported = {
        ("Linux", "x86_64"): ("Linux", "amd64", "linux_amd64"),
        ("Darwin", "x86_64"): ("Mac", "amd64", "darwin_amd64"),
        ("Darwin", "arm64"): ("Mac", "arm64", "darwin_arm64"),
    }
    if (system, arch) not in supported:
        raise Exception("Unsupported operating system and architecture combination: %s, %s" % (system, arch))

    os_label, arch_label, remote_directory = supported[(system, arch)]
    print("Operating system is %s with architecture %s." % (os_label, arch_label))
    for binary in ("server", "executor", "migrator"):
        # --fail makes curl exit non-zero on an HTTP error; without it the
        # error response body would be saved as the "binary".
        execute_command([
            "curl", "--fail",
            os.path.join(s3_prefix, "bin", remote_directory, binary),
            "--output", os.path.join(server_directory, "bin", binary),
        ])

def update_server_version():
    """Bring the server installation to ``package_version``.

    Rewrites the server's ``__version__`` marker, downloads the platform
    binaries and marks them executable, downloads the helper shell scripts,
    and then runs the migrator.

    Raises:
        Exception: If any download, chmod or migrator command fails.
    """
    print("Updating server version to %s" % package_version)
    version_file = os.path.join(server_directory, "__version__")
    if os.path.isfile(version_file):
        execute_command(["rm", version_file])
    generate_version_file(version_file)

    s3_prefix = "https://aqueduct-ai.s3.us-east-2.amazonaws.com/assets/%s/server" % package_version
    download_golang_binaries(s3_prefix)
    update_executable_permissions()

    for script in ("start-function-executor.sh", "install_sqlserver_ubuntu.sh"):
        # --fail makes curl exit non-zero on an HTTP error instead of saving
        # the server's error response as the script.
        execute_command([
            "curl", "--fail",
            os.path.join(s3_prefix, "bin", script),
            "--output", os.path.join(server_directory, "bin", script),
        ])

    # NOTE(review): presumably moves the sqlite schema to migration 8 - confirm
    # against the migrator's CLI.
    execute_command([os.path.join(server_directory, "bin", "migrator"), "--type", "sqlite", "goto", "8"])

def generate_welcome_message(expose):
    """Render the start-up banner with the API key and the host to show.

    Args:
        expose: When truthy, the banner shows the ``<IP_ADDRESS>`` placeholder
            as host; otherwise it shows ``localhost``.

    Returns:
        The formatted welcome message.
    """
    host = "<IP_ADDRESS>" if expose else "localhost"
    return welcome_message % (get_apikey(), host)

def start(expose):
    """Install or update the UI and server as needed, then launch the server.

    Steps:
      1. Download the UI if its directory is missing or its version is stale.
      2. If the server directory is missing, create its layout, install the
         server, and fetch and fill in the config file and the demo database.
         On failure the server directory is removed and the process exits
         with status 1.
      3. If the server version is stale, update it.  On failure the version
         marker is removed and the process exits with status 1.
      4. Start the server binary with ``ui_directory`` as working directory.

    Args:
        expose: When truthy, ``--expose`` is passed to the server binary.

    Returns:
        The ``subprocess.Popen`` handle of the running server.
    """
    if not os.path.isdir(ui_directory):
        update_ui_version()

    if require_update(os.path.join(ui_directory, "__version__")):
        update_ui_version()

    config_path = os.path.join(server_directory, "config", "config.yml")

    if not os.path.isdir(server_directory):
        try:
            directories = [
                server_directory,
                os.path.join(server_directory, "db"),
                os.path.join(server_directory, "storage"),
                os.path.join(server_directory, "storage", "operators"),
                os.path.join(server_directory, "vault"),
                os.path.join(server_directory, "bin"),
                os.path.join(server_directory, "config"),
                os.path.join(server_directory, "logs"),
            ]

            for directory in directories:
                os.mkdir(directory)

            update_server_version()

            s3_prefix = "https://aqueduct-ai.s3.us-east-2.amazonaws.com/assets/%s/server" % package_version
            # --fail makes curl exit non-zero on an HTTP error instead of
            # saving the server's error response as the config/database file.
            execute_command(["curl", "--fail", os.path.join(s3_prefix, "config/config.yml"), "--output", config_path])
            update_config_yaml(config_path)
            execute_command(["curl", "--fail", os.path.join(s3_prefix, "db/demo.db"), "--output", os.path.join(server_directory, "db/demo.db")])

            print("Finished initializing Aqueduct base directory.")
        except Exception as e:
            print(e)
            execute_command(["rm", "-rf", server_directory])
            # sys.exit rather than the site-provided exit(), which is intended
            # for interactive sessions.
            sys.exit(1)

    version_file = os.path.join(server_directory, "__version__")
    if require_update(version_file):
        try:
            update_server_version()
        except Exception as e:
            print(e)
            # Drop the marker so the next run retries the update.
            if os.path.isfile(version_file):
                execute_command(["rm", version_file])
            sys.exit(1)

    server_command = [os.path.join(server_directory, "bin", "server"), "--config", config_path]
    if expose:
        server_command.append("--expose")
    return execute_command_nonblocking(server_command, ui_directory)

def install_postgres():
    """Install the ``psycopg2-binary`` package with pip for the running interpreter."""
    pip_install = [sys.executable, "-m", "pip", "install"]
    execute_command(pip_install + ["psycopg2-binary"])

def install_bigquery():
    """Install the ``pandas-gbq`` package with pip for the running interpreter."""
    pip_install = [sys.executable, "-m", "pip", "install"]
    execute_command(pip_install + ["pandas-gbq"])

def install_snowflake():
    """Install the ``snowflake-sqlalchemy`` package with pip for the running interpreter."""
    pip_install = [sys.executable, "-m", "pip", "install"]
    execute_command(pip_install + ["snowflake-sqlalchemy"])

def install_s3():
    """Install the ``pyarrow`` package with pip for the running interpreter."""
    pip_install = [sys.executable, "-m", "pip", "install"]
    execute_command(pip_install + ["pyarrow"])

def install_mysql():
    """Install the system packages for MySQL support, then ``mysqlclient``.

    Uses apt-get on Ubuntu/Debian, yum on CentOS/RHEL and Homebrew on macOS.
    On any other platform a notice is printed and the system step is skipped;
    the pip install of ``mysqlclient`` is attempted in every case.
    """
    system = platform.system()
    if system == "Darwin":
        execute_command(["brew", "install", "mysql"])
    elif system == "Linux":
        distribution = distro.id()
        if distribution in ("ubuntu", "debian"):
            execute_command(["sudo", "apt-get", "install", "-y", "python3-dev", "default-libmysqlclient-dev", "build-essential"])
        elif distribution in ("centos", "rhel"):
            execute_command(["sudo", "yum", "install", "-y", "python3-devel", "mysql-devel"])
        else:
            print("Unsupported distribution:", distribution)
    else:
        print("Unsupported operating system:", system)

    pip_install = [sys.executable, "-m", "pip", "install"]
    execute_command(pip_install + ["mysqlclient"])

def install_sqlserver():
    """Install the system ODBC driver for SQL Server, then ``pyodbc``.

    On Ubuntu this runs the ``install_sqlserver_ubuntu.sh`` script from the
    server's bin directory; on macOS it installs ``msodbcsql17`` and
    ``mssql-tools`` through Homebrew.  On any other platform a notice is
    printed and the system step is skipped; the pip install of ``pyodbc`` is
    attempted in every case.
    """
    system = platform.system()
    if system == "Linux":
        if distro.id() == "ubuntu":
            execute_command(["bash", os.path.join(server_directory, "bin", "install_sqlserver_ubuntu.sh")])
        else:
            print("Unsupported distribution:", distro.id())
    elif system == "Darwin":
        execute_command(["brew", "tap", "microsoft/mssql-release", "https://github.com/Microsoft/homebrew-mssql-release"])
        execute_command(["brew", "update"])
        # The command runs without a shell, so VAR=value words are not treated
        # as environment assignments; the first word would be executed as the
        # program.  Let the `env` utility set the variables for brew instead.
        execute_command(["env", "HOMEBREW_NO_ENV_FILTERING=1", "ACCEPT_EULA=Y", "brew", "install", "msodbcsql17", "mssql-tools"])
    else:
        print("Unsupported operating system:", system)

    execute_command([sys.executable, "-m", "pip", "install", "pyodbc"])

def install(system):
    """Install the dependencies for the connector named ``system``.

    Args:
        system: One of "postgres", "bigquery", "snowflake", "s3", "mysql"
            or "sqlserver".

    Raises:
        Exception: If ``system`` is not one of the supported names.
    """
    supported = ("postgres", "bigquery", "snowflake", "s3", "mysql", "sqlserver")
    if system not in supported:
        raise Exception("Unsupported system: %s" % system)

    installers = {
        "postgres": install_postgres,
        "bigquery": install_bigquery,
        "snowflake": install_snowflake,
        "s3": install_s3,
        "mysql": install_mysql,
        "sqlserver": install_sqlserver,
    }
    installers[system]()

def ui(expose_ip):
    """Install or update the UI if needed, then start it with ``npm run start``.

    Args:
        expose_ip: IP address to use as the server address; when falsy,
            ``localhost`` is used.

    Returns:
        The ``subprocess.Popen`` handle of the npm process, which runs in the
        ``app`` directory under ``ui_directory``.
    """
    if not os.path.isdir(ui_directory):
        update_ui_version()

    version_file = os.path.join(ui_directory, "__version__")
    if require_update(version_file):
        update_ui_version()

    server_ip = "localhost"
    if expose_ip:
        server_ip = expose_ip
        print("The Aqueduct server at " + server_ip + " will be externally accessible.")

    # NOTE(review): generate_env_local is not defined in the visible part of
    # this file - confirm where it comes from before calling ui().
    generate_env_local(server_ip)
    app_directory = os.path.join(ui_directory, "app")
    return execute_command_nonblocking(["npm", "run", "start"], app_directory)

def get_apikey():
    """Return the ``apiKey`` entry of the server's ``config/config.yml``.

    If the file cannot be parsed as YAML, the parse error is printed and the
    process exits with status 1.

    Returns:
        The value stored under the ``apiKey`` key.
    """
    config_file = os.path.join(server_directory, "config", "config.yml")
    with open(config_file, "r") as f:
        try:
            return yaml.safe_load(f)['apiKey']
        except yaml.YAMLError as exc:
            print(exc)
            # sys.exit rather than the site-provided exit(), which is intended
            # for interactive sessions.
            sys.exit(1)

def apikey():
    """Print the API key read from the server configuration."""
    key = get_apikey()
    print(key)

def clear():
    """Remove the whole Aqueduct base directory with ``rm -rf``."""
    removal_command = ["rm", "-rf", base_directory]
    execute_command(removal_command)

def version():
    """Print the package version."""
    current = package_version
    print(current)

# Command-line entry point: parse the sub-command and dispatch to its handler.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='The Aqueduct CLI')
    subparsers = parser.add_subparsers(dest="command")

    # --expose is a plain flag here (it takes no value), so the help text must
    # not advertise an <IP_ADDRESS> argument.
    start_args = subparsers.add_parser('start', help=
                               '''This starts the Aqueduct server and the UI in a blocking
                               fashion. To background the process run aqueduct start &.

                               Add --expose to access the Aqueduct service from
                               an external server, using the public IP of the
                               server running the Aqueduct service.
                               ''')
    start_args.add_argument('--expose', default=False, action='store_true',
                    help="Use this option to expose the server to the public.")
    server_args = subparsers.add_parser('server', help=
                                   '''This starts the Aqueduct server in a
                                   blocking fashion. To background the process,
                                   run aqueduct server &.''')
    server_args.add_argument('--expose', default=False, action='store_true',
                    help="Use this option to expose the server to the public.")
    ui_args = subparsers.add_parser('ui', help=
                               '''This starts the Aqueduct UI in a blocking
                               fashion. To background the process run aqueduct
                               ui &.

                               Add --expose <IP_ADDRESS> to access the UI from
                               an external server, where <IP_ADDRESS> is the
                               public IP of the server you are running on.
                               ''')
    ui_args.add_argument('--expose', dest="expose",
                    help="The IP address of the server running Aqueduct.")

    install_args = subparsers.add_parser('install', help=
                             '''Install the required library dependencies for
                             an Aqueduct connector to a third-party system.''')
    install_args.add_argument('system', nargs=1, help="Supported integrations: postgres, mysql, sqlserver, s3, snowflake, bigquery.")

    apikey_args = subparsers.add_parser('apikey', help="Display your Aqueduct API key.")
    clear_args = subparsers.add_parser('clear', help="Erase your Aqueduct installation.")
    version_args = subparsers.add_parser('version', help="Retrieve the package version number.")

    args = parser.parse_args()

    # exist_ok avoids failing if the directory appears between a check and the
    # creation.
    os.makedirs(base_directory, exist_ok=True)

    if args.command == "start":
        # Bind the name before the try block: if start() itself raises, the
        # handler must not fail with a NameError on popen_handle.
        popen_handle = None
        try:
            popen_handle = start(args.expose)
            time.sleep(2)
            print(generate_welcome_message(args.expose))
            popen_handle.wait()
        except (Exception, KeyboardInterrupt) as e:
            print(e)
            if popen_handle is not None:
                print('\nTerminating Aqueduct service...')
                popen_handle.kill()
                print('Aqueduct service successfully terminated.')
    elif args.command == "server":
        print("aqueduct ui and aqueduct server have been deprecated; please use aqueduct start to run both the UI and backend servers")
    elif args.command == "install":
        install(args.system[0]) # argparse makes this an array so only pass in value [0].
    elif args.command == "ui":
        print("aqueduct ui and aqueduct server have been deprecated; please use aqueduct start to run both the UI and backend servers")
    elif args.command == "apikey":
        apikey()
    elif args.command == "clear":
        clear()
    elif args.command == "version":
        version()
    elif args.command is None:
        parser.print_help()
    else:
        print("Unsupported command:", args.command)
        sys.exit(1)
