#!/usr/bin/env python3
import requests
import os
import hashlib
import shutil
import tarfile
from tqdm import tqdm

def download_tar(installation_dirs):
    """Download the ROM archive to ``installation_dirs[0] + "Roms.tar.gz"``.

    Only the first entry of *installation_dirs* is used. It must be an existing
    directory whose path ends with a separator, because the target file path is
    built by plain string concatenation (as elsewhere in this module).

    Raises:
        requests.HTTPError: if the server answers with an error status, so an
            error page is never written out as the archive.
    """
    install_dir = installation_dirs[0]
    tar_link = "https://roms8.s3.us-east-2.amazonaws.com/Roms.tar.gz"
    tar_file_title = install_dir + "Roms.tar.gz"
    bars = 20
    pbar_format = "{desc}:{percentage:3.0f}%|{bar}|{elapsed}{postfix}"

    with requests.get(tar_link, stream=True) as response:
        response.raise_for_status()

        content_length = response.headers.get("Content-Length")
        total_file_size = int(content_length) if content_length else 0
        if total_file_size > 0:
            # Ceiling division: the body arrives in at most `bars` chunks, so each
            # chunk advances the progress bar by one step (and is never size 0).
            download_chunk_size = -(-total_file_size // bars)
            total = bars
        else:
            # Size unknown (e.g. chunked transfer): use a fixed chunk size and an
            # open-ended progress bar instead of crashing on the missing header.
            download_chunk_size = 1 << 20
            total = None

        # `with` guarantees the partial file is closed even if the download fails.
        with open(tar_file_title, "wb") as tar_file:
            for chunk in tqdm(
                response.iter_content(chunk_size=download_chunk_size),
                bar_format=pbar_format,
                total=total,
                desc="Downloading ROMs",
                leave=False,
            ):
                tar_file.write(chunk)

def _checked_members(tar, dest):
    """Return *tar*'s members after verifying none of them can escape *dest*.

    Fallback for interpreters whose ``tarfile`` lacks extraction filters. All
    members are validated before anything is extracted, so a bad archive leaves
    nothing behind. Links are refused outright rather than resolved.
    """
    root = os.path.realpath(dest)
    members = tar.getmembers()
    for member in members:
        target = os.path.realpath(os.path.join(root, member.name))
        try:
            inside = os.path.commonpath([root, target]) == root
        except ValueError:  # e.g. different drives on Windows
            inside = False
        if not inside:
            raise tarfile.TarError(
                "Refusing to extract %r outside %r" % (member.name, dest)
            )
        if member.issym() or member.islnk():
            raise tarfile.TarError("Refusing to extract link %r" % member.name)
    return members


def extract_tar_content(installation_dirs):
    """Extract ``Roms.tar.gz`` from ``installation_dirs[0]`` into that same directory.

    ``installation_dirs[0]`` must end with a path separator; the archive path is
    built by string concatenation. The archive comes from the network, so its
    members are not trusted: anything that would be written outside the target
    directory is rejected.

    Raises:
        tarfile.TarError: if the archive cannot be read or contains an unsafe
            member.
    """
    install_dir = installation_dirs[0]
    with tarfile.open(install_dir + "Roms.tar.gz") as tar:
        if hasattr(tarfile, "data_filter"):
            # Python 3.12+ and security backports: the stdlib "data" filter
            # rejects members (including links) that resolve outside `install_dir`.
            tar.extractall(install_dir, filter="data")
        else:
            tar.extractall(install_dir, members=_checked_members(tar, install_dir))


def _md5_of_file(path, block_size=65536):
    """Return the hex MD5 digest of the file at *path*, read in *block_size* chunks."""
    hash_md5 = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(block_size), b""):
            hash_md5.update(chunk)
    return hash_md5.hexdigest()


def transfer_rom_files(installation_dirs, checksum_map):
    """Move every recognised ROM from ``<install_dir>ROM/`` into a per-game folder.

    Each file found (recursively) under ``installation_dirs[0] + "ROM/"`` is
    MD5-hashed. If the hex digest is a key of *checksum_map*, the file is moved to
    ``<install_dir><game>/<checksum_map[digest]>``. Here ``<game>`` is the mapped
    filename with its last four characters removed (presumably a ``.bin``-style
    extension; verify against checksums.txt). The matched entry is then deleted
    from *checksum_map*, so on return *checksum_map* holds only the ROMs that were
    not found. Files whose digest is unknown are left in place.

    Args:
        installation_dirs: list whose first entry is the install directory,
            ending with a path separator.
        checksum_map: mapping of MD5 hex digest -> ROM filename; mutated in place.
    """
    install_dir = installation_dirs[0]
    zip_dir = install_dir + "ROM/"
    for subdir, _, files in os.walk(zip_dir):
        for file_name in files:
            source = os.path.join(subdir, file_name)
            # Hash first, so the file is closed before it is moved: renaming an
            # open file fails on Windows.
            digest = _md5_of_file(source)
            if digest not in checksum_map:
                continue
            rom_name = checksum_map[digest]
            game_subdir = install_dir + rom_name[:-4] + "/"
            os.makedirs(game_subdir, exist_ok=True)
            # os.replace overwrites an existing destination on every platform,
            # whereas os.rename raises on Windows.
            os.replace(source, os.path.join(game_subdir, rom_name))
            del checksum_map[digest]

def clean_tar_files(installation_dirs):
    """Remove the downloaded archive and the extracted ``ROM/`` tree.

    Works inside ``installation_dirs[0]``. It deletes ``Roms.tar.gz`` and the
    ``ROM/`` directory left behind by extraction, each only if it exists, so
    calling this twice is harmless.
    """
    root = installation_dirs[0]
    # (path, remover) pairs, processed in order: archive file first, then the tree.
    leftovers = (
        (os.path.join(root, "Roms.tar.gz"), os.remove),
        (os.path.join(root, "ROM/"), shutil.rmtree),
    )
    for path, remove in leftovers:
        if os.path.exists(path):
            remove(path)

def _package_rom_dir(module):
    """Return the ``ROM/`` directory inside *module*'s package folder, or None.

    Modules without a ``__file__`` (e.g. namespace packages) yield None; the
    caller treats that as "package not usable".
    """
    module_file = getattr(module, "__file__", None)
    if module_file is None:
        return None
    # dirname() instead of slicing off len("__init__.py") characters, so this does
    # not silently break if __file__ has a different basename. The trailing "/" is
    # kept because the other helpers build paths by string concatenation.
    return os.path.join(os.path.dirname(module_file), "ROM/")


def _load_checksum_map(path):
    """Parse *path* (lines of ``<md5 hex> <rom filename>``) into a digest -> name dict."""
    checksum_map = {}
    with open(path, "rb") as checksum_file:
        for raw_line in checksum_file:
            fields = raw_line.split()
            if len(fields) < 2:
                # Skip blank/malformed lines (e.g. a trailing newline) instead of
                # crashing with IndexError.
                continue
            checksum_map[fields[0].decode("utf-8")] = fields[1].decode("utf-8")
    return checksum_map


def main(license_accepted=False, specific_dir=None):
    """Download the Atari 2600 ROMs and install them for the available ALE packages.

    Targets a ``ROM/`` folder inside each importable package among ``ale_py`` and
    ``multi_agent_ale_py``. When *specific_dir* is given, it targets only
    ``<specific_dir>/ROM/`` instead. The first target directory is deleted and
    recreated. The archive is then downloaded, extracted and sorted by checksum
    there, and the result is copied over any further target directories. ROMs
    listed in ``checksums.txt`` (next to this script) that were not found are
    printed as missing.

    Args:
        license_accepted: when True, skip the interactive license confirmation.
        specific_dir: optional base directory that overrides the package locations.

    Raises:
        SystemExit: if neither package is installed (status 1) or the user
            declines the license prompt (status 0).
    """
    import sys

    try:
        import ale_py
    except ImportError:
        ale_py = None
    try:
        import multi_agent_ale_py
    except ImportError:
        multi_agent_ale_py = None

    ale_install_dir = _package_rom_dir(ale_py) if ale_py is not None else None
    multi_ale_install_dir = (
        _package_rom_dir(multi_agent_ale_py)
        if multi_agent_ale_py is not None
        else None
    )
    ale_installed = ale_install_dir is not None
    multi_ale_installed = multi_ale_install_dir is not None

    if not ale_installed and not multi_ale_installed:
        print("Neither ale_py nor multi_agent_ale_py is installed, quitting.")
        sys.exit(1)

    installation_dirs = [
        d for d in (ale_install_dir, multi_ale_install_dir) if d is not None
    ]

    if specific_dir:
        dir_path = os.path.abspath(os.path.join(specific_dir, "ROM/")) + "/"
        installation_dirs = [dir_path]
        ale_install_dir = dir_path
        multi_ale_install_dir = dir_path

    script_dir = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
    checksum_map = _load_checksum_map(os.path.join(script_dir, "checksums.txt"))

    license_text = ""
    if ale_installed:
        license_text += ale_install_dir + "\nfor use with ALE-Py (and Gym)"
    if ale_installed and multi_ale_installed:
        license_text += " and also\n\t"
    if multi_ale_installed:
        license_text += multi_ale_install_dir + "\nfor use with Multi-Agent-ALE-py."
    # The original message read "...ROMs from . \n" with no source after "from".
    print("AutoROM will download the Atari 2600 ROMs.\nThey will be installed to\n\t" +
          license_text + " Existing ROMS will be overwritten.")
    if not license_accepted:
        ans = input("\nI own a license to these Atari 2600 ROMs, agree not to " +
                    "distribute these ROMS, and wish to proceed (Y or N). ")
        if ans not in ("Y", "y"):
            # sys.exit instead of quit(): quit() only exists when `site` is loaded.
            sys.exit()

    primary_dir = installation_dirs[0]
    # Start from an empty primary directory ("Existing ROMS will be overwritten").
    if os.path.exists(primary_dir):
        shutil.rmtree(primary_dir)
    os.makedirs(primary_dir)

    download_tar(installation_dirs)
    extract_tar_content(installation_dirs)
    transfer_rom_files(installation_dirs, checksum_map)
    clean_tar_files(installation_dirs)

    # Mirror the finished primary directory into every other target.
    for secondary in installation_dirs[1:]:
        if os.path.exists(secondary):
            shutil.rmtree(secondary)
        shutil.copytree(primary_dir, secondary)

    # transfer_rom_files removed every entry it found; what remains is missing.
    for rom_name in checksum_map.values():
        print("Missing: ", rom_name)
    print("Done!")

if __name__ == "__main__":
    import argparse

    # Command-line entry point: parse flags and hand them to main().
    cli = argparse.ArgumentParser(description="Process arguments")
    cli.add_argument(
        "-v", "--accept",
        action="store_true",
        default=False,
        help="Accept license agreement",
    )
    cli.add_argument(
        "-d", "--dir",
        type=str,
        default=None,
        help="Installation directory",
    )
    options = cli.parse_args()
    main(options.accept, specific_dir=options.dir)
