import boto3
import pandas as pd
import os
import numpy as np
import onnxruntime as rt
import json
import sklearn
from sklearn.metrics import accuracy_score
from sklearn.metrics import f1_score
from sklearn.metrics import precision_score
from sklearn.metrics import recall_score
from sklearn.metrics import roc_auc_score
from sklearn.metrics import mean_squared_error
from sklearn.metrics import r2_score
from sklearn.metrics import mean_absolute_error
from math import sqrt
import pickle
import six

def get_ytestdata(ytest_s3_filename="ytest.pkl"):
    """Download the pickled ground-truth targets from S3 and unpickle them.

    The object is fetched from the configured bucket under the model's key
    prefix, written to a scratch file in /tmp, and then loaded with pickle.

    Args:
        ytest_s3_filename: Name of the pickle object under the model's key
            prefix. With the default value the S3 key and the local scratch
            path are the same as the previously hard-coded ones.

    Returns:
        Whatever object was pickled into the file (presumably the y_test
        labels/targets used for scoring; verify against the uploader).
    """
    s3 = boto3.resource("s3")
    bucket = s3.Bucket("$bucket_name")

    # Use only the last path component for the scratch file so that a
    # filename containing slashes still lands directly inside /tmp.
    local_path = os.path.join("/tmp", os.path.basename(ytest_s3_filename))

    with open(local_path, "wb") as ytestfo:
        bucket.download_fileobj("$unique_model_id/" + ytest_s3_filename, ytestfo)

    # NOTE(security): pickle.load can execute arbitrary code while loading.
    # This is only safe as long as the bucket contents are fully trusted.
    with open(local_path, "rb") as ytestfi:
        return pickle.load(ytestfi)

def model_eval_metrics(y_true, y_pred, classification="TRUE"):
    """Score predictions against the true values and return a metrics dict.

    Args:
        y_true: Ground-truth values.
        y_pred: Predicted values, aligned with ``y_true``.
        classification: The string "TRUE" selects the classification metrics
            (accuracy plus macro-averaged F1, precision and recall). Any
            other value selects the regression metrics (MSE, RMSE, MAE, R2).

    Returns:
        A dict with the keys accuracy, f1_score, precision, recall, mse,
        rmse, mae and r2. Metrics from the family that was not selected are
        reported as 0.
    """
    # Every metric starts at 0; only the selected family is overwritten.
    scores = {
        'accuracy': 0,
        'f1_score': 0,
        'precision': 0,
        'recall': 0,
        'mse': 0,
        'rmse': 0,
        'mae': 0,
        'r2': 0,
    }

    if classification == "TRUE":
        macro_options = {"average": "macro", "zero_division": 0}
        scores['accuracy'] = accuracy_score(y_true, y_pred)
        scores['f1_score'] = f1_score(y_true, y_pred, **macro_options)
        scores['precision'] = precision_score(y_true, y_pred, **macro_options)
        scores['recall'] = recall_score(y_true, y_pred, **macro_options)
    else:
        squared_error = mean_squared_error(y_true, y_pred)
        scores['mse'] = squared_error
        scores['rmse'] = sqrt(squared_error)
        scores['mae'] = mean_absolute_error(y_true, y_pred)
        scores['r2'] = r2_score(y_true, y_pred)

    # The single-row DataFrame round trip is kept on purpose: presumably it
    # normalises the scalar types of the values before they are serialised
    # by the caller -- confirm before simplifying it away.
    single_row = pd.DataFrame.from_dict(
        {metric: [value] for metric, value in scores.items()}
    )
    return single_row.to_dict('records')[0]

# Loaded once at import time (not per request), so every handler call in the
# same process reuses the already-downloaded ground-truth data.
ytestdata = get_ytestdata(ytest_s3_filename="ytest.pkl")

def evaluate_model(event, ytestdata):
    """Extract the predictions from the request and score them.

    Args:
        event: Request mapping; its "body" entry holds the predictions,
            either as a JSON-encoded string or as an already-decoded object.
        ytestdata: Ground-truth values to score the predictions against.

    Returns:
        The metrics dict produced by ``model_eval_metrics``, which is always
        called here with classification="FALSE" (regression metrics).
    """
    import six

    payload = event["body"]
    print(payload)

    # A string body is decoded as JSON; anything else is used as-is.
    if isinstance(payload, six.string_types):
        predictions = json.loads(payload)
    else:
        predictions = payload
    print(predictions)

    return model_eval_metrics(ytestdata, predictions, classification="FALSE")

def handler(event, context):
    """Entry point: score the submitted predictions and wrap them in a response.

    Args:
        event: Request mapping passed through to ``evaluate_model``.
        context: Runtime context object; not used by this function.

    Returns:
        A dict with "statusCode" 200, permissive CORS "headers", and a "body"
        holding the metrics dict serialised as JSON.
    """
    cors_headers = {
        "Access-Control-Allow-Origin" : "*",
        "Access-Control-Allow-Credentials": True,
        "Allow" : "GET, OPTIONS, POST",
        "Access-Control-Allow-Methods" : "GET, OPTIONS, POST",
        "Access-Control-Allow-Headers" : "*",
    }

    # Scoring uses the module-level ground truth loaded at import time.
    metrics = evaluate_model(event, ytestdata)

    return {
        "statusCode": 200,
        "headers": cors_headers,
        "body": json.dumps(metrics),
    }