#!/usr/bin/env python3

import argparse
import os
import platform
import random
import shutil
import socket
import string
import subprocess
import sys
import time
import zipfile

import distro
import requests
import yaml

# Schema version passed to the migrator's `goto` command.
# NOTE: the spelling of this identifier is kept because other code in this file refers to it.
SCHEMA_VERION = "12"

package_version = "0.0.5"
default_server_port = 8080

# Local installation layout: the server and UI directories live under the base directory.
base_directory = os.path.join(os.environ["HOME"], ".aqueduct")
server_directory = os.path.join(base_directory, "server")
ui_directory = os.path.join(base_directory, "ui")

# Default location of the AWS credentials file.
aws_credentials_path = os.path.join(os.environ["HOME"], ".aws", "credentials")

# Download prefixes for the release assets of this package version.
s3_server_prefix = (
    "https://aqueduct-ai.s3.us-east-2.amazonaws.com/assets/%s/server" % package_version
)
s3_ui_prefix = "https://aqueduct-ai.s3.us-east-2.amazonaws.com/assets/%s/ui" % package_version

# Banner template; filled with (api key, host, port).
welcome_message = """
***************************************************
Your API Key: %s

The Web UI and the backend server are accessible at: http://%s:%d
***************************************************
"""


def update_config_yaml(file):
    """Fill in the placeholders of a server config file, rewriting it in place.

    Every occurrence of ``<BASE_PATH>`` is replaced with the server directory,
    and ``<ENCRYPTION_KEY>`` / ``<API_KEY>`` are replaced with newly generated
    32-character keys made of uppercase ASCII letters and digits.

    Args:
        file: Path of the config file to update.
    """
    # Function-scope import: `secrets` is only needed for key generation here.
    import secrets

    alphabet = string.ascii_uppercase + string.digits
    # These values are credentials, so draw them from the OS CSPRNG instead of
    # the `random` module. Characters are chosen independently (repeats are
    # allowed), which keeps the full key space available.
    encryption_key = "".join(secrets.choice(alphabet) for _ in range(32))
    api_key = "".join(secrets.choice(alphabet) for _ in range(32))

    replacements = (
        ("<BASE_PATH>", server_directory),
        ("<ENCRYPTION_KEY>", encryption_key),
        ("<API_KEY>", api_key),
    )

    with open(file, "r") as sources:
        content = sources.read()
    # Substitute every placeholder, including several different ones on one line.
    for placeholder, value in replacements:
        content = content.replace(placeholder, value)
    with open(file, "w") as sources:
        sources.write(content)
    print("Updated configurations.")


def execute_command(args, cwd=None):
    """Run a command to completion, sharing this process's stdout and stderr.

    NOTE: a function with the same name is defined again later in this file.

    Args:
        args: The command and its arguments, as accepted by ``subprocess.Popen``.
        cwd: Optional working directory for the child process.

    Raises:
        Exception: If the command exits with a non-zero return code.
    """
    child = subprocess.Popen(args, cwd=cwd, stdout=sys.stdout, stderr=sys.stderr)
    with child:
        # No pipes are attached, so this only blocks until the child exits.
        child.communicate()
        exit_code = child.returncode
        if exit_code != 0:
            raise Exception("Error executing command: %s" % args)


def generate_version_file(file_path):
    """Write the current package version to ``file_path``, replacing any existing content."""
    with open(file_path, "w") as version_file:
        version_file.write(package_version)


# Returns a bool indicating whether we need to perform a version upgrade.
def require_update(file_path):
    """Decide whether the installation recorded in ``file_path`` must be updated.

    Args:
        file_path: Path of a version file as written by ``generate_version_file``.

    Returns:
        True if the file is missing or empty, or if it records a version older
        than ``package_version``; False if it records the same version.

    Raises:
        Exception: If the recorded version is newer than ``package_version``.
    """

    def version_key(version):
        # Compare dotted versions component by component and numerically, so
        # that "0.0.10" is newer than "0.0.9" (plain string comparison gets
        # this wrong). Non-numeric components sort after numeric ones and are
        # compared as text.
        return tuple(
            (0, int(part), "") if part.isdecimal() else (1, 0, part)
            for part in version.split(".")
        )

    if not os.path.isfile(file_path):
        return True
    with open(file_path, "r") as f:
        # Ignore surrounding whitespace such as a trailing newline.
        current_version = f.read().strip()

    if not current_version:
        # An empty version file carries no information; treat it as outdated.
        return True

    installed = version_key(current_version)
    packaged = version_key(package_version)
    if packaged < installed:
        raise Exception(
            "Attempting to install an older version %s but found existing newer version %s"
            % (package_version, current_version)
        )
    return packaged != installed


def update_executable_permissions():
    """Set mode 0o755 on the server, executor and migrator files in the server bin directory."""
    bin_directory = os.path.join(server_directory, "bin")
    for binary_name in ("server", "executor", "migrator"):
        os.chmod(os.path.join(bin_directory, binary_name), 0o755)


def download_server_binaries(architecture):
    """Download the server binaries and helper scripts into the server ``bin`` directory.

    Args:
        architecture: Name of the platform directory under the release's
            ``bin/`` prefix (the callers in this file pass values such as
            ``"linux_amd64"``); the server, executor and migrator binaries
            are fetched from it.

    Raises:
        requests.HTTPError: If any download answers with an HTTP error status.
    """
    # (path relative to the server download prefix, file name under bin/)
    assets = [
        (f"bin/{architecture}/server", "server"),
        (f"bin/{architecture}/executor", "executor"),
        (f"bin/{architecture}/migrator", "migrator"),
        ("bin/start-function-executor.sh", "start-function-executor.sh"),
        ("bin/install_sqlserver_ubuntu.sh", "install_sqlserver_ubuntu.sh"),
    ]
    bin_directory = os.path.join(server_directory, "bin")
    for remote_path, file_name in assets:
        # URL components are always separated by "/".
        response = requests.get("%s/%s" % (s3_server_prefix, remote_path))
        # Fail loudly instead of saving an HTTP error page as if it were the asset.
        response.raise_for_status()
        with open(os.path.join(bin_directory, file_name), "wb") as f:
            f.write(response.content)
    print("Downloaded server binaries.")


def setup_server_binaries():
    """Recreate the server ``bin`` directory and download the binaries for this platform.

    Raises:
        Exception: If the operating system / architecture combination is not
            one of the supported ones.
    """
    print("Downloading server binaries.")
    bin_directory = os.path.join(server_directory, "bin")
    # Remove any existing bin directory and create it anew.
    shutil.rmtree(bin_directory, ignore_errors=True)
    os.mkdir(bin_directory)

    # (system, machine) -> (message to print, architecture name to download)
    supported_platforms = {
        ("Linux", "x86_64"): (
            "Operating system is Linux with architecture amd64.",
            "linux_amd64",
        ),
        ("Darwin", "x86_64"): (
            "Operating system is Mac with architecture amd64.",
            "darwin_amd64",
        ),
        ("Darwin", "arm64"): (
            "Operating system is Mac with architecture arm64.",
            "darwin_arm64",
        ),
    }

    system = platform.system()
    arch = platform.machine()
    if (system, arch) not in supported_platforms:
        raise Exception(
            "Unsupported operating system and architecture combination: %s, %s" % (system, arch)
        )

    message, architecture = supported_platforms[(system, arch)]
    print(message)
    download_server_binaries(architecture)


def update_ui_version():
    """Replace the UI directory with the UI assets of the current package version.

    Recreates the UI directory, writes its ``__version__`` file, downloads
    ``ui.zip`` (the ``sagemaker`` build if ``/home/ec2-user/SageMaker``
    exists, the ``default`` build otherwise), extracts it into the UI
    directory and deletes the archive. On any failure the error is printed,
    the UI directory is removed and the process exits with status 1.
    """
    print("Updating UI version to %s" % package_version)
    try:
        shutil.rmtree(ui_directory, ignore_errors=True)
        os.mkdir(ui_directory)
        generate_version_file(os.path.join(ui_directory, "__version__"))

        # We detect whether the server is running on a SageMaker instance by checking if the
        # directory /home/ec2-user/SageMaker exists. This is hacky but we couldn't find a better
        # solution at the moment.
        on_sagemaker = os.path.isdir(os.path.join(os.sep, "home", "ec2-user", "SageMaker"))
        flavor = "sagemaker" if on_sagemaker else "default"
        response = requests.get("%s/%s/ui.zip" % (s3_ui_prefix, flavor))
        # Fail here rather than saving an HTTP error page as ui.zip.
        response.raise_for_status()

        ui_zip_path = os.path.join(ui_directory, "ui.zip")
        with open(ui_zip_path, "wb") as f:
            f.write(response.content)
        # Named `archive` so the builtin `zip` is not shadowed.
        with zipfile.ZipFile(ui_zip_path, "r") as archive:
            archive.extractall(ui_directory)
        os.remove(ui_zip_path)
    except Exception as e:
        print(e)
        shutil.rmtree(ui_directory, ignore_errors=True)
        # sys.exit rather than the site-provided exit(), which is meant for
        # interactive use and is not guaranteed to be defined.
        sys.exit(1)


def update_server_version():
    """Install the server binaries of the current package version and migrate the database.

    Rewrites the server's ``__version__`` file, downloads the binaries, makes
    them executable and runs the migrator with ``--type sqlite goto <schema version>``.
    """
    print("Updating server version to %s" % package_version)

    version_path = os.path.join(server_directory, "__version__")
    # Drop a previous version file, if any, before writing the new one.
    if os.path.isfile(version_path):
        os.remove(version_path)
    generate_version_file(version_path)

    setup_server_binaries()
    update_executable_permissions()

    migrator_binary = os.path.join(server_directory, "bin", "migrator")
    migrate_command = [migrator_binary, "--type", "sqlite", "goto", SCHEMA_VERION]
    execute_command(migrate_command)


def _download_file(url, destination):
    """Download ``url`` and save the response body to ``destination``.

    Raises:
        requests.HTTPError: If the server answers with an HTTP error status.
    """
    response = requests.get(url)
    # Without this check an HTTP error page would be saved as the requested asset.
    response.raise_for_status()
    with open(destination, "wb") as f:
        f.write(response.content)


def _initialize_server_directory():
    """Create the server directory tree and fill it with binaries, config and the demo DB."""
    directories = [
        server_directory,
        os.path.join(server_directory, "db"),
        os.path.join(server_directory, "storage"),
        os.path.join(server_directory, "storage", "operators"),
        os.path.join(server_directory, "vault"),
        os.path.join(server_directory, "bin"),
        os.path.join(server_directory, "config"),
        os.path.join(server_directory, "logs"),
    ]
    for directory in directories:
        os.mkdir(directory)

    update_server_version()

    config_path = os.path.join(server_directory, "config", "config.yml")
    _download_file("%s/config/config.yml" % s3_server_prefix, config_path)
    update_config_yaml(config_path)

    _download_file(
        "%s/db/demo.db" % s3_server_prefix,
        os.path.join(server_directory, "db", "demo.db"),
    )

    print("Finished initializing Aqueduct base directory.")


def update():
    """Bring the local installation up to date with the current package version.

    Creates the base directory if needed, refreshes the UI when it is missing
    or outdated, initializes the server directory when it does not exist, and
    upgrades the server when its version file is missing or outdated. Failures
    while initializing or upgrading are printed and end the process with
    status 1.
    """
    os.makedirs(base_directory, exist_ok=True)

    if not os.path.isdir(ui_directory) or require_update(os.path.join(ui_directory, "__version__")):
        update_ui_version()

    if not os.path.isdir(server_directory):
        try:
            _initialize_server_directory()
        except Exception as e:
            print(e)
            # Remove the partially initialized directory so that the next run
            # takes this initialization branch again.
            shutil.rmtree(server_directory, ignore_errors=True)
            sys.exit(1)

    version_file = os.path.join(server_directory, "__version__")
    if require_update(version_file):
        try:
            update_server_version()
        except Exception as e:
            print(e)
            # Without a version file the next run attempts the upgrade again.
            if os.path.isfile(version_file):
                os.remove(version_file)
            sys.exit(1)

def execute_command(args, cwd=None):
    """Run a command and wait for it to finish, sharing this process's stdout and stderr.

    NOTE: this rebinds the ``execute_command`` name that is already defined
    earlier in this file.

    Args:
        args: The command and its arguments, as accepted by ``subprocess.Popen``.
        cwd: Optional working directory for the child process.

    Raises:
        Exception: If the command exits with a non-zero return code.
    """
    with subprocess.Popen(args, stdout=sys.stdout, stderr=sys.stderr, cwd=cwd) as proc:
        proc.communicate()
        if proc.returncode != 0:
            # Wrap args in a 1-tuple: a bare tuple of arguments would otherwise
            # be unpacked by "%" and raise TypeError instead of this error.
            raise Exception("Error executing command: %s" % (args,))

def execute_command_nonblocking(args, cwd=None):
    """Start a command without waiting for it and return its ``subprocess.Popen`` handle.

    The child shares this process's stdout and stderr.
    """
    process = subprocess.Popen(args, cwd=cwd, stdout=sys.stdout, stderr=sys.stderr)
    return process

def generate_welcome_message(expose, port):
    """Fill the welcome banner with the API key and the address of the service.

    Args:
        expose: If true, the banner shows the ``<IP_ADDRESS>`` placeholder as
            host; otherwise it shows ``localhost``.
        port: Port number shown in the banner.

    Returns:
        The formatted banner text.
    """
    host = "<IP_ADDRESS>" if expose else "localhost"
    return welcome_message % (get_apikey(), host, port)

def is_port_in_use(port: int) -> bool:
    """Return True if a TCP connection to ``localhost`` on ``port`` succeeds."""
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    with probe:
        # connect_ex reports failure through its return value; 0 means connected.
        status = probe.connect_ex(("localhost", port))
        return status == 0

def start(expose, port):
    """Update the installation and launch the server process without waiting for it.

    Args:
        expose: If true, ``--expose`` is passed to the server.
        port: Port to run on. If None, the first port at or above the default
            port that is not in use is chosen.

    Returns:
        A ``(popen_handle, server_port)`` tuple.
    """
    update()

    if port is not None:
        server_port = int(port)
        print("Server will use the user-specified port %d" % server_port)
    else:
        server_port = default_server_port
        while is_port_in_use(server_port):
            server_port += 1
        if server_port != default_server_port:
            print("Default port %d is occupied. Next available port is %d" % (default_server_port, server_port))

    # Both variants share everything except the optional --expose flag.
    command = [
        os.path.join(server_directory, "bin", "server"),
        "--config",
        os.path.join(server_directory, "config", "config.yml"),
    ]
    if expose:
        command.append("--expose")
    command += ["--port", str(server_port)]

    return execute_command_nonblocking(command), server_port

def install_postgres():
    """Install ``psycopg2-binary`` with pip, using the running interpreter."""
    pip_command = [sys.executable, "-m", "pip", "install", "psycopg2-binary"]
    execute_command(pip_command)

def install_bigquery():
    """Install ``google-cloud-bigquery`` with pip, using the running interpreter."""
    pip_command = [sys.executable, "-m", "pip", "install", "google-cloud-bigquery"]
    execute_command(pip_command)

def install_snowflake():
    """Install ``snowflake-sqlalchemy`` with pip, using the running interpreter."""
    pip_command = [sys.executable, "-m", "pip", "install", "snowflake-sqlalchemy"]
    execute_command(pip_command)

def install_s3():
    """Install ``pyarrow`` with pip, using the running interpreter."""
    pip_command = [sys.executable, "-m", "pip", "install", "pyarrow"]
    execute_command(pip_command)

def install_mysql():
    """Install the system packages for MySQL support, then ``mysqlclient`` with pip.

    System packages are installed with apt-get (ubuntu/debian), yum
    (centos/rhel) or brew (macOS). On any other platform a message is printed
    and only the pip install is attempted.
    """
    system = platform.system()
    if system == "Darwin":
        execute_command(["brew", "install", "mysql"])
    elif system == "Linux":
        distribution = distro.id()
        if distribution in ("ubuntu", "debian"):
            execute_command(["sudo", "apt-get", "install", "-y", "python3-dev", "default-libmysqlclient-dev", "build-essential"])
        elif distribution in ("centos", "rhel"):
            execute_command(["sudo", "yum", "install", "-y", "python3-devel", "mysql-devel"])
        else:
            print("Unsupported distribution:", distribution)
    else:
        print("Unsupported operating system:", system)

    execute_command([sys.executable, "-m", "pip", "install", "mysqlclient"])

def install_sqlserver():
    """Install the system ODBC driver for SQL Server, then ``pyodbc`` with pip.

    On Ubuntu the ``install_sqlserver_ubuntu.sh`` script from the server bin
    directory is run; on macOS the driver is installed with brew. On any other
    platform a message is printed and only the pip install is attempted.
    """
    system = platform.system()
    if system == "Linux":
        if distro.id() == "ubuntu":
            execute_command(["bash", os.path.join(server_directory, "bin", "install_sqlserver_ubuntu.sh")])
        else:
            print("Unsupported distribution:", distro.id())
    elif system == "Darwin":
        execute_command(["brew", "tap", "microsoft/mssql-release", "https://github.com/Microsoft/homebrew-mssql-release"])
        execute_command(["brew", "update"])
        # The command is executed without a shell, so leading NAME=value words
        # would be treated as the program to run. `env` applies the variable
        # assignments and then executes brew.
        execute_command(["env", "HOMEBREW_NO_ENV_FILTERING=1", "ACCEPT_EULA=Y", "brew", "install", "msodbcsql17", "mssql-tools"])
    else:
        print("Unsupported operating system:", system)

    execute_command([sys.executable, "-m", "pip", "install", "pyodbc"])

def install(system):
    """Install the dependencies for the connector named ``system``.

    Args:
        system: One of ``postgres``, ``bigquery``, ``snowflake``, ``s3``,
            ``mysql`` or ``sqlserver``.

    Raises:
        Exception: If ``system`` is not one of the supported names.
    """
    supported_systems = ("postgres", "bigquery", "snowflake", "s3", "mysql", "sqlserver")
    if system not in supported_systems:
        raise Exception("Unsupported system: %s" % system)

    installers = {
        "postgres": install_postgres,
        "bigquery": install_bigquery,
        "snowflake": install_snowflake,
        "s3": install_s3,
        "mysql": install_mysql,
        "sqlserver": install_sqlserver,
    }
    installers[system]()

def get_apikey():
    """Read the API key from the server's ``config.yml``.

    Returns:
        The value stored under the ``apiKey`` key of the config file.

    If the file cannot be parsed as YAML, the parse error is printed and the
    process exits with status 1.
    """
    config_file = os.path.join(server_directory, "config", "config.yml")
    with open(config_file, "r") as f:
        try:
            config = yaml.safe_load(f)
        except yaml.YAMLError as exc:
            print(exc)
            # sys.exit, as elsewhere in this file, rather than the
            # site-provided exit(), which is meant for interactive use.
            sys.exit(1)
    return config['apiKey']

def apikey():
    """Print the API key read from the server config."""
    key = get_apikey()
    print(key)

def clear():
    """Delete the whole Aqueduct base directory, ignoring any errors while doing so."""
    shutil.rmtree(path=base_directory, ignore_errors=True)

def version():
    """Print the package version to stdout."""
    current = package_version
    print(current)

def read_config():
    """Load the server's ``config.yml`` and return its parsed contents."""
    config_path = os.path.join(server_directory, "config", "config.yml")
    with open(config_path, 'r') as config_file:
        return yaml.safe_load(config_file)

def write_config(config):
    """Serialize ``config`` as YAML into the server's ``config.yml``, overwriting it."""
    config_path = os.path.join(server_directory, "config", "config.yml")
    with open(config_path, 'w') as config_file:
        yaml.dump(config, config_file)

def use_local_storage(path):
    """Switch the server config to file storage in a local directory.

    Args:
        path: Storage directory. It is joined onto the server directory, so a
            relative path is placed under it.
    """
    config = read_config()
    config["storageConfig"] = {
        "type": "file",
        "fileConfig": {"directory": os.path.join(server_directory, path)},
    }
    write_config(config)

def use_s3_storage(region, bucket, creds_path, creds_profile):
    """Switch the server config to S3 storage.

    Args:
        region: Stored as ``region`` in the S3 config.
        bucket: S3 path; must start with ``s3://``, otherwise a message is
            printed and the process exits with status 1.
        creds_path: Stored as ``credentialsPath`` in the S3 config.
        creds_profile: Stored as ``credentialsProfile`` in the S3 config.
    """
    well_formed = bucket.startswith("s3://")
    if not well_formed:
        print('S3 path is malformed, it should be of the form s3://')
        sys.exit(1)

    config = read_config()
    config["storageConfig"] = {
        "type": "s3",
        "s3Config": {
            "region": region,
            "bucket": bucket,
            "credentialsPath": creds_path,
            "credentialsProfile": creds_profile,
        },
    }
    write_config(config)

if __name__ == "__main__":
    # Command-line entry point: declare the sub-commands, parse the arguments
    # and dispatch on the chosen sub-command.
    parser = argparse.ArgumentParser(description='The Aqueduct CLI')
    subparsers = parser.add_subparsers(dest="command")

    start_args = subparsers.add_parser('start', help=
                               '''This starts the Aqueduct server and the UI in a blocking
                               fashion. To background the process run aqueduct start &.

                               Add --expose <IP_ADDRESS> to access the Aqueduct service from
                               an external server, where <IP_ADDRESS> is the
                               public IP of the server running the Aqueduct service.
                               ''')
    start_args.add_argument('--expose', default=False, action='store_true',
                    help="Use this option to expose the server to the public.")
    start_args.add_argument('--port', dest="port", help="Specify the port on which the Aqueduct server runs.")
    server_args = subparsers.add_parser('server', help=
                                   '''[DEPRECATED] This starts the Aqueduct server in a
                                   blocking fashion. To background the process,
                                   run aqueduct server &.''')
    server_args.add_argument('--expose', default=False, action='store_true',
                    help="Use this option to expose the server to the public.")
    ui_args = subparsers.add_parser('ui', help=
                               '''[DEPRECATED] This starts the Aqueduct UI in a blocking
                               fashion. To background the process run aqueduct
                               ui &.

                               Add --expose <IP_ADDRESS> to access the UI from
                               an external server, where <IP_ADDRESS> is the
                               public IP of the server you are running on.
                               ''')
    ui_args.add_argument('--expose', dest="expose",
                    help="The IP address of the server running Aqueduct.")

    install_args = subparsers.add_parser('install', help=
                             '''Install the required library dependencies for
                             an Aqueduct connector to a third-party system.''')
    install_args.add_argument('system', nargs=1, help="Supported integrations: postgres, mysql, sqlserver, s3, snowflake, bigquery.")

    apikey_args = subparsers.add_parser('apikey', help="Display your Aqueduct API key.")
    clear_args = subparsers.add_parser('clear', help="Erase your Aqueduct installation.")
    version_args = subparsers.add_parser('version', help="Retrieve the package version number.")

    storage_args = subparsers.add_parser('storage', help=
                               '''This changes the storage location for any new workflows created.
                               The change will take affect once you restart the Aqueduct server.
                               We are currently working on adding support for modifying the storage
                               location of existing workflows.
                               ''')
    storage_args.add_argument('--use', dest="storage_use",
                                choices=["local", "s3"], required=True,
                                help="The following storage locations are supported: local, s3"
                            )
    storage_args.add_argument('--path', dest="storage_path", required=True,
                                help=
                                '''For local storage this is the filepath of the storage directory.
                                This should be relative to the Aqueduct installation path.
                                For S3 storage this is the S3 path, which should be of the form:
                                s3://bucket/path/to/folder
                                '''
                            )
    storage_args.add_argument('--region', dest="storage_s3_region",
                                help="The AWS S3 region where the bucket is located."
                            )
    storage_args.add_argument('--credentials', dest="storage_s3_creds",
                                default=aws_credentials_path,
                                help='''The filepath to the AWS credentials to use.
                                It defaults to ~/.aws/credentials'''
                            )
    storage_args.add_argument('--profile', dest="storage_s3_profile",
                                default="default",
                                help='''The AWS credentials profile to use. It uses default
                                if none is provided.'''
                            )

    args = parser.parse_args()
    sysargs = sys.argv

    if args.command == "start":
        # Bound before the try block so the handler below can tell whether the
        # server process was ever started (start() itself may raise).
        popen_handle = None
        try:
            popen_handle, server_port = start(args.expose, args.port)
            time.sleep(1)
            # poll() returns None while the process is running and its exit
            # code (which may be 0) once it has terminated.
            if popen_handle.poll() is not None:
                print("Server terminated due to an error.")
            else:
                print(generate_welcome_message(args.expose, server_port))
                popen_handle.wait()
        except (Exception, KeyboardInterrupt) as e:
            print(e)
            print('\nTerminating Aqueduct service...')
            if popen_handle is not None:
                popen_handle.kill()
            print('Aqueduct service successfully terminated.')
    elif args.command == "server":
        print("aqueduct ui and aqueduct server have been deprecated; please use aqueduct start to run both the UI and backend servers")
    elif args.command == "install":
        install(args.system[0]) # argparse makes this an array so only pass in value [0].
    elif args.command == "ui":
        print("aqueduct ui and aqueduct server have been deprecated; please use aqueduct start to run both the UI and backend servers")
    elif args.command == "apikey":
        apikey()
    elif args.command == "clear":
        clear()
    elif args.command == "version":
        version()
    elif args.command == "storage":
        if args.storage_use == "local":
            # Ensure that S3 args are not provided for local storage.
            # --credentials and --profile have defaults, so the parsed values
            # cannot tell whether they were given; inspect the raw arguments,
            # accepting both the "--flag value" and "--flag=value" spellings.
            s3_args = ["--region", "--credentials", "--profile"]
            for s3_arg in s3_args:
                if any(arg == s3_arg or arg.startswith(s3_arg + "=") for arg in sysargs):
                    print("{} should not be used with local storage".format(s3_arg))
                    sys.exit(1)
            use_local_storage(args.storage_path)
        elif args.storage_use == "s3":
            # Ensure that required S3 args are provided. --region has no
            # default, so None means it was not given in any spelling.
            if args.storage_s3_region is None:
                print("--region is required when using S3 storage")
                sys.exit(1)
            use_s3_storage(
                args.storage_s3_region,
                args.storage_path,
                args.storage_s3_creds,
                args.storage_s3_profile,
            )
        else:
            print("Unsupported storage type: ", args.storage_use)
            sys.exit(1)
    elif args.command is None:
        parser.print_help()
    else:
        print("Unsupported command:", args.command)
        sys.exit(1)
