import json
import os
import logging

from typing import List, Dict, Set
from teradataml.analytics.valib import *
from teradataml import DataFrame
from aoa.context.model_context import ModelContext
from aoa.stats.stats_util import (
    _capture_stats,
    _NpEncoder,
    _parse_scoring_stats,
    _compute_continuous_edges
)

logger = logging.getLogger(__name__)


def record_training_stats(df: DataFrame,
                          features: List[str],
                          targets: List[str],
                          categorical: List[str] = None,
                          context: ModelContext = None,
                          feature_importance: Dict[str, float] = None,
                          category_ordinals: Set[str] = None) -> Dict:
    """
    Compute and record the dataset statistics used for training. This information provides ModelOps with a snapshot
    of the dataset at this point in time (i.e. at the point of training). ModelOps uses this information for data and
    prediction drift monitoring. It can also be used for data quality monitoring as all of the information which is
    captured here is available to configure an alert on (e.g. max > some_threshold).

    Depending on the type of variable (categorical or continuous), different statistics and distributions are computed.
    All of this is computed in Vantage via the Vantage Analytics Library (VAL).

    Continuous Variable:
        Distribution: Histogram
        Statistics: Min, Max, Average, Skew, etc, nulls

    Categorical Variable:
        Distribution: Frequency
        Statistics: nulls

    The computed statistics are written as JSON to ``data_stats.json`` under ``context.artifact_output_path`` when a
    context is given, otherwise to ``artifacts/output/data_stats.json`` relative to the working directory.

    The following example shows how you would use this function for a binary classification problem where there
    are 3 features and 1 target. As it is classification, the target must be categorical and in this case, the features
    are all continuous.
    example usage:
        training_df = DataFrame.from_query("SELECT * from my_table")

        record_training_stats(training_df,
                              features=["feat1", "feat2", "feat3"],
                              targets=["targ1"],
                              categorical=["targ1"],
                              context=context)

    :param df: teradataml dataframe used for training with the feature and target variables
    :type df: teradataml.DataFrame
    :param features: feature variable(s) used in this training
    :type features: List[str]
    :param targets: target variable(s) used in this training
    :type targets: List[str]
    :param categorical: (Optional) variable(s) (feature or target) that is categorical, defaults to an empty list
    :type categorical: List[str]
    :param context: (Optional) ModelContext which is associated with that training invocation
    :type context: ModelContext
    :param feature_importance: (Optional) feature importance, defaults to an empty dict
    :type feature_importance: Dict[str, float]
    :param category_ordinals: (Optional) categorical variable(s) which are of ordinal type, defaults to an empty set
    :type category_ordinals: Set[str]
    :return: the computed data statistics
    :rtype: Dict
    :raise ValueError: if features are not provided
    :raise OSError: if the output data_stats.json file cannot be opened for writing

    NOTE(review): validation of ``targets`` and of the type of ``df`` is not performed in this function; presumably
    the underlying stats helper is responsible for it - confirm.
    """

    logger.info("Computing training dataset statistics")

    if not features:
        raise ValueError("One or more features must be provided")

    # Build fresh containers per call; mutable default arguments would be shared between invocations.
    categorical = [] if categorical is None else categorical
    feature_importance = {} if feature_importance is None else feature_importance
    category_ordinals = set() if category_ordinals is None else category_ordinals

    feature_metadata_fqtn = None
    feature_metadata_group = None
    data_stats_filename = "artifacts/output/data_stats.json"

    # A context overrides the default output location and supplies the feature metadata settings.
    if context:
        feature_metadata_fqtn = context.dataset_info.get_feature_metadata_fqtn()
        feature_metadata_group = context.dataset_info.feature_metadata_monitoring_group
        data_stats_filename = os.path.join(context.artifact_output_path, "data_stats.json")

    data_stats = _capture_stats(df=df,
                                features=features,
                                targets=targets,
                                categorical=categorical,
                                category_ordinals=category_ordinals,
                                feature_importance=feature_importance,
                                feature_metadata_fqtn=feature_metadata_fqtn,
                                feature_metadata_group=feature_metadata_group)

    with open(data_stats_filename, 'w+') as f:
        json.dump(data_stats, f, indent=2, cls=_NpEncoder)

    return data_stats


def record_evaluation_stats(features_df: DataFrame,
                            predicted_df: DataFrame,
                            feature_importance: Dict[str, float] = None,
                            context: ModelContext = None,
                            **kwargs) -> Dict:
    """
    Compute and record the dataset statistics used for evaluation. This information provides ModelOps with a snapshot
    of the dataset at this point in time (i.e. at the point of evaluation). ModelOps uses this information for data
    and prediction drift monitoring. It can also be used for data quality monitoring as all of the information which
    is captured here is available to configure an alert on (e.g. max > some_threshold).

    Depending on the type of variable (categorical or continuous), different statistics and distributions are computed.
    All of this is computed in Vantage via the Vantage Analytics Library (VAL).

    Continuous Variable:
        Distribution: Histogram
        Statistics: Min, Max, Average, Skew, etc, nulls

    Categorical Variable:
        Distribution: Frequency
        Statistics: nulls

    The statistics recorded at training time are read from ``data_stats.json`` under ``context.artifact_input_path``
    (or ``artifacts/input/data_stats.json`` without a context) and the evaluation statistics are written to
    ``data_stats.json`` under ``context.artifact_output_path`` (or ``artifacts/output/data_stats.json``).

    example usage:
        features_df = DataFrame.from_query("SELECT * from my_features_table")

        predicted_df = model.predict(features_df)

        record_evaluation_stats(features_df=features_df,
                                predicted_df=predicted_df,
                                context=context)

    :param features_df: dataframe containing feature variable(s) from evaluation
    :type features_df: teradataml.DataFrame
    :param predicted_df: dataframe containing predicted target variable(s) from evaluation
    :type predicted_df: teradataml.DataFrame
    :param feature_importance: (Optional) feature importance, defaults to an empty dict
    :type feature_importance: Dict[str, float]
    :param context: (Optional) ModelContext which is associated with that evaluation invocation
    :type context: ModelContext
    :param kwargs: additional keyword arguments; accepted but not used by this function
    :return: the computed data statistics
    :rtype: Dict
    :raise FileNotFoundError: if the training data_stats.json input file does not exist
    :raise OSError: if the output data_stats.json file cannot be opened for writing

    NOTE(review): row-count and dataframe type validation are not performed in this function; presumably the
    underlying stats helper is responsible for them - confirm.
    """

    logger.info("Computing evaluation dataset statistics")

    # Build a fresh dict per call; a mutable default argument would be shared between invocations.
    feature_importance = {} if feature_importance is None else feature_importance

    feature_metadata_fqtn = None
    feature_metadata_group = None
    output_data_stats_filename = "artifacts/output/data_stats.json"
    input_data_stats_filename = "artifacts/input/data_stats.json"

    # A context overrides the default artifact locations and supplies the feature metadata settings.
    if context:
        feature_metadata_fqtn = context.dataset_info.get_feature_metadata_fqtn()
        feature_metadata_group = context.dataset_info.feature_metadata_monitoring_group
        output_data_stats_filename = os.path.join(context.artifact_output_path, "data_stats.json")
        input_data_stats_filename = os.path.join(context.artifact_input_path, "data_stats.json")

    # The statistics captured at training time are passed to the stats helper together with the new data.
    with open(input_data_stats_filename, 'r') as f:
        training_data_stats = json.load(f)

    data_stats = _parse_scoring_stats(features_df=features_df,
                                      predicted_df=predicted_df,
                                      data_stats=training_data_stats,
                                      feature_importance=feature_importance,
                                      feature_metadata_fqtn=feature_metadata_fqtn,
                                      feature_metadata_group=feature_metadata_group)

    # Persist the evaluation statistics alongside the other output artifacts.
    with open(output_data_stats_filename, 'w+') as f:
        json.dump(data_stats, f, indent=2, cls=_NpEncoder)

    return data_stats


def record_scoring_stats(features_df: DataFrame,
                         predicted_df: DataFrame,
                         context: ModelContext = None) -> Dict:
    """
    Compute and record the dataset statistics used for scoring. This information provides ModelOps with a snapshot
    of the dataset at this point in time (i.e. at the point of scoring). ModelOps uses this information for data
    and prediction drift monitoring. It can also be used for data quality monitoring as all of the information which
    is captured here is available to configure an alert on (e.g. max > some_threshold).

    Depending on the type of variable (categorical or continuous), different statistics and distributions are computed.
    All of this is computed in Vantage via the Vantage Analytics Library (VAL).

    Continuous Variable:
        Distribution: Histogram
        Statistics: Min, Max, Average, Skew, etc, nulls

    Categorical Variable:
        Distribution: Frequency
        Statistics: nulls

    The statistics recorded at training time are read from ``data_stats.json`` under ``context.artifact_input_path``
    (or ``artifacts/input/data_stats.json`` without a context) and the scoring statistics are written to
    ``data_stats.json`` under ``context.artifact_output_path`` (or ``artifacts/output/data_stats.json``).

    example usage:
        features_df = DataFrame.from_query("SELECT * from my_features_table")

        predicted_df = model.predict(features_df)

        record_scoring_stats(features_df=features_df,
                             predicted_df=predicted_df,
                             context=context)

    :param features_df: dataframe containing feature variable(s) from scoring
    :type features_df: teradataml.DataFrame
    :param predicted_df: dataframe containing predicted target variable(s) from scoring
    :type predicted_df: teradataml.DataFrame
    :param context: (Optional) ModelContext which is associated with that scoring invocation
    :type context: ModelContext
    :return: the computed data statistics
    :rtype: Dict
    :raise FileNotFoundError: if the training data_stats.json input file does not exist
    :raise OSError: if the output data_stats.json file cannot be opened for writing

    NOTE(review): row-count and dataframe type validation are not performed in this function; presumably the
    underlying stats helper is responsible for them - confirm.
    """

    logger.info("Computing scoring dataset statistics")

    feature_metadata_fqtn = None
    feature_metadata_group = None
    input_data_stats_filename = "artifacts/input/data_stats.json"
    # Default output path so the function also works without a context (same default as record_evaluation_stats);
    # previously this name was only bound inside the `if context:` branch.
    output_data_stats_filename = "artifacts/output/data_stats.json"

    # A context overrides the default artifact locations and supplies the feature metadata settings.
    if context:
        feature_metadata_fqtn = context.dataset_info.get_feature_metadata_fqtn()
        feature_metadata_group = context.dataset_info.feature_metadata_monitoring_group
        input_data_stats_filename = os.path.join(context.artifact_input_path, "data_stats.json")
        output_data_stats_filename = os.path.join(context.artifact_output_path, "data_stats.json")

    # The statistics captured at training time are passed to the stats helper together with the new data.
    with open(input_data_stats_filename, 'r') as f:
        training_data_stats = json.load(f)

    data_stats = _parse_scoring_stats(features_df=features_df,
                                      predicted_df=predicted_df,
                                      data_stats=training_data_stats,
                                      feature_metadata_fqtn=feature_metadata_fqtn,
                                      feature_metadata_group=feature_metadata_group)

    # Persist the scoring statistics alongside the other output artifacts.
    with open(output_data_stats_filename, 'w+') as f:
        json.dump(data_stats, f, indent=2, cls=_NpEncoder)

    return data_stats


def compute_continuous_stats(features_df: DataFrame, continuous_features: List):
    """
    Compute the histogram edges for the given continuous features.

    Statistics for the requested columns are obtained with ``valib.Statistics`` (``stats_options="all"``) and handed,
    together with the dataframe column types, to ``_compute_continuous_edges`` requesting 10 bins.

    :param features_df: dataframe containing the continuous feature variable(s)
    :type features_df: teradataml.DataFrame
    :param continuous_features: names of the continuous variable(s) to compute edges for
    :type continuous_features: List
    :return: mapping of lower-cased variable name to ``{"edges": <edges computed for that variable>}``
    :rtype: Dict
    :raise Exception: if fewer than 10 edges were computed for any variable
    """
    bins = 10

    # Column types keyed by lower-cased column name, as expected by the edge computation helper.
    column_types = {}
    for entry in features_df.dtypes._column_names_and_types:
        column_types[entry[0].lower()] = entry[1]

    val_output = valib.Statistics(data=features_df, columns=continuous_features, stats_options="all")
    stats_frame = val_output.result.to_pandas().reset_index()

    computed_edges = _compute_continuous_edges(continuous_features, stats_frame, column_types, bins=bins)
    edges_by_variable = dict(zip(continuous_features, computed_edges))

    column_stats = {}
    for variable_name, edges in edges_by_variable.items():
        # Too few edges usually means the variable does not have enough distinct values to be binned.
        if len(edges) < bins:
            raise Exception(f"Variable {variable_name} has only {len(edges)} bins computed when 10 should "
                            f"have been computed {edges}.\n"
                            f"Please ensure the variable is not categorical (use -t categorical).")
        column_stats[variable_name.lower()] = {"edges": edges}

    return column_stats


def compute_categorical_stats(features_df: DataFrame, categorical_features: List):
    """
    Compute the list of categories (distinct values) for the given categorical features.

    The values are obtained with ``valib.Frequency``; the ``xcol`` and ``xval`` columns of its result are used,
    with ``xcol`` lower-cased before grouping (presumably the column name and the observed value - confirm against
    the VAL Frequency documentation).

    :param features_df: dataframe containing the categorical feature variable(s)
    :type features_df: teradataml.DataFrame
    :param categorical_features: names of the categorical variable(s) to compute categories for
    :type categorical_features: List
    :return: mapping of lower-cased variable name to ``{"categories": [<values in result order>]}``
    :rtype: Dict
    :raise KeyError: if a requested feature has no rows in the frequency result
    """
    frequency = valib.Frequency(data=features_df, columns=categorical_features)
    frequency = frequency.result.to_pandas().reset_index()

    # Column names are normalised to lower case, so every lookup below must use the lower-cased name as well.
    frequency["xcol"] = frequency["xcol"].str.lower()

    # dict.fromkeys de-duplicates the values while keeping their first-seen order.
    categories_by_column = {
        column: list(dict.fromkeys(rows["xval"]))
        for column, rows in frequency.groupby("xcol")
    }

    # Look up with the lower-cased name: using the original casing raised KeyError for names with upper-case letters.
    column_stats = {f.lower(): {"categories": categories_by_column[f.lower()]} for f in categorical_features}

    return column_stats
