#!/usr/bin/env python
# Copyright (c) atomate Development Team.

import argparse
import ast
import importlib
import os
import sys
from datetime import datetime

import yaml
from fireworks import LaunchPad
from monty.serialization import loadfn
from pymatgen.core import Lattice, Structure
from pymatgen.ext.matproj import MPRester
from pymatgen.util.testing import PymatgenTest

from atomate.common.powerups import add_namefile, add_tags
from atomate.utils.utils import get_wf_from_spec_dict, load_class
from atomate.vasp.workflows.presets import core

# Workflow spec used when neither a spec file nor a preset is requested, and
# as the "standard" workflow of the test suite. It lists four fireworks:
# OptimizeFW, StaticFW (parents: 0) and two NonSCFFW entries (parents: 1) with
# modes "uniform" and "line".
# NOTE(review): `parents` presumably indexes into the `fireworks` list above
# it -- confirm against get_wf_from_spec_dict.
# NOTE(review): `db_file: None` is presumably parsed by YAML as the string
# "None" rather than a null value -- confirm this is what consumers expect.
default_yaml = """fireworks:
- fw: atomate.vasp.fireworks.core.OptimizeFW
- fw: atomate.vasp.fireworks.core.StaticFW
  params:
    parents: 0
- fw: atomate.vasp.fireworks.core.NonSCFFW
  params:
    parents: 1
    mode: uniform
- fw: atomate.vasp.fireworks.core.NonSCFFW
  params:
    parents: 1
    mode: line
common_params:
  db_file: None
"""

# Absolute path of the directory holding the VASP preset module (`core`);
# library spec files are resolved relative to it.
presets_dir = os.path.dirname(os.path.abspath(core.__file__))
# LaunchPad shared by every subcommand. It is created when this module is
# loaded, so importing or running the script triggers LaunchPad.auto_load().
lpad = LaunchPad.auto_load()


def add_to_lpad(workflow, write_namefile=False):
    """
    Insert a workflow into the module-level LaunchPad.

    Args:
        workflow (Workflow): workflow to add to the database
        write_namefile (bool): when true, the workflow is first passed
            through the ``add_namefile`` powerup so that an empty file named
            "FW--<fw.name>" is written to the launch directory
    """
    if write_namefile:
        workflow = add_namefile(workflow)
    lpad.add_wf(workflow)


def _get_wf(args, structure):
    """
    Build the workflow selected by the "add" subcommand options.

    The source of the workflow is chosen in this order of precedence:
    a spec file (``args.spec_file``), a preset function (``args.preset``),
    or the built-in ``default_yaml`` spec.

    Args:
        args (argparse.Namespace): parsed options; ``spec_file``, ``preset``,
            ``library`` and ``common_param_updates`` are read.
        structure (Structure): structure handed to the workflow constructor.

    Returns:
        The object returned by ``get_wf_from_spec_dict`` or by the preset
        function.

    Raises:
        ValueError: if ``--library`` names an unknown library together with
            ``--spec``, or if a preset outside the "vasp" library is not
            given as a dotted ``module.function`` path.
    """
    if args.spec_file:
        spec_path = args.spec_file
        if args.library:
            if args.library.lower() != "vasp":
                raise ValueError(f"Unknown library: {args.library}")
            # Library spec files live next to the presets package.
            spec_path = os.path.join(presets_dir, "..", "base", "library", spec_path)
        d = loadfn(spec_path)
        return get_wf_from_spec_dict(structure, d, args.common_param_updates)

    if args.preset:
        if args.library and args.library.lower() == "vasp":
            modname = "atomate.vasp.workflows.presets.core"
            funcname = args.preset
        else:
            # Without the "vasp" library (including no --library at all, which
            # previously left modname/funcname unassigned) the preset must be
            # a fully qualified "package.module.function" path.
            modname, _, funcname = args.preset.rpartition(".")
            if not modname or not funcname:
                raise ValueError(
                    f"Preset '{args.preset}' must be a dotted path "
                    "'module.function', or be used together with --library vasp"
                )

        mod = importlib.import_module(modname)
        func = getattr(mod, funcname)
        return func(structure)

    # safe_load: yaml.load() without an explicit Loader is rejected by
    # PyYAML >= 6, and the default spec only contains plain scalars.
    d = yaml.safe_load(default_yaml)
    return get_wf_from_spec_dict(structure, d, args.common_param_updates)


def add_wf(args):
    """
    Create one workflow per entry of ``args.files`` and add it to the LaunchPad.

    Each entry is read as a structure file unless ``args.mp`` is set, in
    which case it is passed to MPRester as a material id.

    Args:
        args (argparse.Namespace): parsed options of the "add" subcommand.
    """
    for entry in args.files:
        if args.mp:
            structure = MPRester().get_structure_by_material_id(entry)
        else:
            structure = Structure.from_file(entry)
        workflow = _get_wf(args, structure)
        add_to_lpad(workflow, write_namefile=False)


def submit_test_suite(args):
    """
    Creates a test suite of workflows and adds it to the LaunchPad

    The suite consists of the default workflow (``default_yaml``) for a few
    small structures plus every callable in the presets ``core`` module whose
    name starts with "wf". All workflows are tagged "test set <timestamp>" so
    that they can be found again later.

    Args:
        args (argparse.Namespace): parsed options of the "test" subcommand;
            if ``args.reset`` is set the LaunchPad is reset first.
    """
    dt = datetime.utcnow()
    if args.reset:
        lpad.reset(password="", require_password=False)

    # Structures for standard workflow
    compounds = ["Si", "CsCl"]
    structs = [PymatgenTest.get_structure(c) for c in compounds]
    structs.append(
        Structure.from_spacegroup(
            "Fm-3m", Lattice.cubic(4.204), ["Ni", "O"], [[0, 0, 0], [0.5, 0.5, 0.5]]
        )
    )
    structs = [s.get_primitive_structure() for s in structs]
    # Perturbed copy of the first structure, used as the second NEB argument.
    scopy = structs[0].copy()
    scopy.perturb(0.1)
    # Positional arguments for specific workflows; presets not listed here
    # are called with the first structure only.
    custom_args = {
        "wf_elastic_constant": [PymatgenTest.get_structure("Sn")],
        "wf_piezoelectric_constant": [PymatgenTest.get_structure("SrTiO3")],
        "wf_nudged_elastic_band": [[structs[0], scopy], structs[0]],
    }

    # Add default workflows
    # safe_load: yaml.load() without an explicit Loader is rejected by
    # PyYAML >= 6, and the default spec only contains plain scalars.
    d = yaml.safe_load(default_yaml)
    wfs = [get_wf_from_spec_dict(s, d) for s in structs]
    # Add preset workflows
    for name, wf_func in core.__dict__.items():
        if callable(wf_func) and name[:2] == "wf":
            if name in custom_args:
                wfs.append(wf_func(*custom_args[name]))
            else:
                wfs.append(wf_func(structs[0]))
    # One shared tag per submission identifies the whole test set.
    wfs = [add_tags(wf, f"test set {dt}") for wf in wfs]
    for wf in wfs:
        add_to_lpad(wf, write_namefile=False)


def verify_test_suite(args):
    """
    Print the status of every test set found in the LaunchPad.

    For each tag containing "test set", the fireworks carrying that tag are
    grouped by state and a one-line verdict is printed (all completed, still
    running, or number of fizzled fireworks).

    Args:
        args (argparse.Namespace): parsed options of the "verify" subcommand;
            if ``args.report`` is set, the per-state counts are printed too.
    """
    tags = lpad.fireworks.distinct("spec.tags", {"spec.tags": {"$regex": "test set"}})
    for tag in tags:
        # distinct() on an array field returns every tag of the matching
        # fireworks, so tags unrelated to a test set have to be skipped here.
        if not isinstance(tag, str) or "test set" not in tag:
            continue
        pipeline = [
            {"$match": {"spec.tags": tag}},
            {"$project": {"state": 1}},
            {"$group": {"_id": "$state", "count": {"$sum": 1}}},
        ]
        rep_dict = {k["_id"]: k["count"] for k in lpad.fireworks.aggregate(pipeline)}
        # Drop the leading "test set" words; keep the space between the date
        # and time parts of the timestamp.
        dt = " ".join(tag.split()[2:])
        print(f"Test set submitted at {dt}")
        if args.report:
            print("   Test fw state summary:")
            for state, n_state in rep_dict.items():
                print(f"   -{state}: {n_state}")
        total = sum(rep_dict.values())
        if rep_dict.get("COMPLETED") == total:
            print("   All test workflows successfully completed")
        elif not rep_dict.get("FIZZLED"):
            print("   No fizzled workflows, testing not yet complete.")
        else:
            print("   {} fireworks have fizzled.".format(rep_dict["FIZZLED"]))


def powerup_workflow(args):
    """
    Apply a powerup function to workflows that already live in the database.

    The workflows are selected either by a single id (``args.wf_id``) or by
    a query given as a dict-like string (``args.query``). The powerup is
    loaded from ``args.module``/``args.name`` and called with the keyword
    arguments parsed from ``args.powerup_kwargs``; afterwards the "_tasks"
    entry of each firework spec is written back to the LaunchPad.

    Args:
        args (argparse.Namespace): parsed options of the "powerup" subcommand.

    Raises:
        ValueError: if both or neither of --wf_id and --query are given.
    """
    if args.wf_id and args.query:
        raise ValueError("Only one of --wf_id and --query may be specified")
    if not (args.wf_id or args.query):
        raise ValueError("At least one of --wf_id or --query must be specified")

    if args.wf_id:
        targets = [lpad.get_wf_by_fw_id(int(args.wf_id))]
    else:
        selector = ast.literal_eval(args.query)
        matched_ids = lpad.get_wf_ids(selector)
        targets = [lpad.get_wf_by_fw_id(matched) for matched in matched_ids]

    apply_powerup = load_class(args.module, args.name)
    extra_kwargs = ast.literal_eval(args.powerup_kwargs)
    for target in targets:
        updated = apply_powerup(target, **extra_kwargs)
        for firework in updated.fws:
            # Only the task list of each firework is pushed back to the database.
            new_tasks = firework.as_dict()["spec"]["_tasks"]
            lpad.update_spec([firework.fw_id], {"_tasks": new_tasks})


if __name__ == "__main__":
    # Command-line entry point: one sub-parser per action, each of which
    # registers its handler function through set_defaults(func=...).
    parser = argparse.ArgumentParser(
        description="atwf is a convenient script to add workflows using a "
        "simple YAML spec.",
        epilog="Author: Shyue Ping Ong",
    )

    subparsers = parser.add_subparsers()

    # "add": create workflows for structure files or Materials Project ids.
    padd = subparsers.add_parser("add", help="Add workflows.")
    padd.add_argument(
        "-l",
        "--library",
        dest="library",
        type=str,
        help="If this option is set, the path to the spec file "
        "is taken with respect to the atomate base library. "
        "Use 'vasp' for the VASP library of workflows",
    )
    padd.add_argument(
        "-s",
        "--spec",
        dest="spec_file",
        type=str,
        help="Specify workflow type using YAML/JSON spec file.",
    )
    padd.add_argument(
        "-p",
        "--preset",
        dest="preset",
        type=str,
        help="Specify workflow type using preset function",
    )
    padd.add_argument(
        "-m",
        "--mp",
        dest="mp",
        action="store_true",
        help="If this option is set, the files argument is "
        "interpreted as a list of Materials Project IDS. "
        "Note that your MAPI_KEY environment variable must "
        "be set to get structures from the Materials "
        "Project.",
    )
    padd.add_argument(
        "-c",
        "--common_params",
        dest="common_param_updates",
        help='Set to a dict-like string, e.g. \'{"a":"b"}\', to set common params',
    )
    padd.add_argument(
        "files",
        metavar="files",
        type=str,
        nargs="+",
        help="Structures to add workflows for.",
    )
    padd.set_defaults(func=add_wf, common_param_updates="{}")

    # "test": submit the test suite, optionally resetting the LaunchPad first.
    ptest = subparsers.add_parser("test", help="Add test suite.")
    ptest.add_argument(
        "-r",
        "--reset",
        dest="reset",
        action="store_true",
        help="If this option is set, launchpad will be reset.",
    )
    ptest.set_defaults(func=submit_test_suite)

    # "verify": report on previously submitted test sets.
    pverify = subparsers.add_parser("verify", help="verify test results.")

    pverify.add_argument(
        "-r",
        "--report",
        dest="report",
        action="store_true",
        help="Use this option to print summary of test workflow states.",
    )
    pverify.set_defaults(func=verify_test_suite)

    # "powerup": apply a powerup function to workflows already in the database.
    ppowerup = subparsers.add_parser("powerup", help="Apply a powerup to a workflow")
    ppowerup.add_argument("-i", "--wf_id", help="Workflow id to powerup")
    ppowerup.add_argument("-q", "--query", help="Query for workflows")
    ppowerup.add_argument(
        "-n", "--name", help="Name of powerup to apply, e.g. add_modify_incar"
    )
    ppowerup.add_argument(
        "-m",
        "--module",
        default="atomate.vasp.powerups",
        help="Module containing powerup",
    )
    ppowerup.add_argument(
        "-pk",
        "--powerup_kwargs",
        default="{}",
        help='powerup keyword arguments, e.g. \'{"incar_update": {"ENCUT": 700}}\'',
    )
    ppowerup.set_defaults(func=powerup_workflow)

    args = parser.parse_args()

    # No subcommand given: there is no handler to call, so show the help.
    if not hasattr(args, "func"):
        parser.print_help()
        sys.exit(0)

    if hasattr(args, "common_param_updates"):
        try:
            args.common_param_updates = ast.literal_eval(
                args.common_param_updates
            )  # str->dict
        except (ValueError, SyntaxError) as exc:
            # Report a malformed literal as a usage error instead of a traceback.
            parser.error(f"--common_params must be a dict-like string: {exc}")

    args.func(args)
