#Text Regression Prediction Runtime Code

import boto3
import pandas as pd
import os
from io import BytesIO
import pickle
import numpy as np
import json
import onnxruntime as rt
import warnings
import six

def get_model_onnx(runtimemodel_s3_filename="runtime_model.onnx"):
    """Download the ONNX model from S3 and open an inference session on it.

    Args:
        runtimemodel_s3_filename: Object name of the model file. It is
            looked up under the model's key prefix in the configured bucket.

    Returns:
        The ``onnxruntime.InferenceSession`` created from the downloaded
        model bytes.
    """
    s3 = boto3.resource('s3')
    # Build the key from the argument (previously the file name was
    # hard-coded and the parameter silently ignored). The default value
    # yields exactly the key that was used before.
    obj = s3.Object("$bucket_name",
                    "$unique_model_id" + "/" + runtimemodel_s3_filename)
    # InferenceSession accepts the serialized model as bytes, so nothing
    # has to be written to disk.
    model = rt.InferenceSession(obj.get()['Body'].read())
    return model

def get_preprocessor(preprocessor_s3_filename="runtime_preprocessor.zip"):
    """Download the preprocessor archive from S3 and load its function.

    The archive is extracted into ``/tmp/``. Every ``*.pkl`` file found in
    that directory is unpickled into a module-level global named after the
    file (without the ``.pkl`` suffix), and ``/tmp/preprocessor.py`` is then
    executed in the module's global namespace so that the code it defines
    can refer to those unpickled objects.

    Args:
        preprocessor_s3_filename: Object name of the zip archive. It is
            looked up under the model's key prefix in the configured bucket.

    Returns:
        The global named ``preprocessor`` as it exists after
        ``preprocessor.py`` has been executed.

    Raises:
        NameError: If no global called ``preprocessor`` exists after the
            script has run.
    """
    import os
    import pickle
    from io import BytesIO
    from zipfile import ZipFile

    pickle_suffix = ".pkl"
    folderpath = "/tmp"

    s3 = boto3.resource("s3")
    # Build the key from the argument (previously hard-coded); the default
    # value yields exactly the key that was used before.
    zip_obj = s3.Object(
        bucket_name="$bucket_name",
        key="$unique_model_id" + "/" + preprocessor_s3_filename)
    buffer = BytesIO(zip_obj.get()["Body"].read())
    with ZipFile(buffer) as archive:
        archive.extractall("/tmp/")

    # SECURITY NOTE: pickle.load and exec below run arbitrary code taken
    # from the downloaded archive. This is only acceptable while the bucket
    # contents are fully trusted.
    for file in os.listdir(folderpath):
        if not file.endswith(pickle_suffix):
            continue
        # Strip only the trailing suffix so the global's name is the file's
        # base name.
        objectname = file[:-len(pickle_suffix)]
        with open(os.path.join(folderpath, file), "rb") as pickle_file:
            globals()[objectname] = pickle.load(pickle_file)

    # The pickled objects must already be in globals() at this point, since
    # the script is executed against that same namespace.
    with open(os.path.join(folderpath, 'preprocessor.py')) as script_file:
        script_source = script_file.read()
    exec(script_source, globals())
    return preprocessor

def get_runtimedata(runtimedata_s3_filename="runtime_data.json"):
    """Fetch the runtime metadata JSON from S3 and return it parsed.

    Args:
        runtimedata_s3_filename: Object name of the JSON file. It is looked
            up under the model's key prefix in the configured bucket.

    Returns:
        The decoded JSON document.
    """
    s3_resource = boto3.resource('s3')
    object_key = "/".join(("$unique_model_id", runtimedata_s3_filename))
    response = s3_resource.Object("$bucket_name", object_key).get()
    # The response body is a file-like stream, so json.load can read it
    # directly.
    return json.load(response['Body'])


# Module-level initialisation: everything below runs once, at import time,
# so handler() can reuse the loaded model and preprocessor on every call.

# Metadata document describing this deployment.
runtime_data=get_runtimedata(runtimedata_s3_filename="runtime_data.json")

# Value of the "runtime_preprocessor" metadata entry (not referenced again in
# the visible code).
preprocessor_type=runtime_data["runtime_preprocessor"]

# Name recorded for the runtime model (not referenced again in the visible
# code).
runtime_model=runtime_data["runtime_model"]["name"]

# Load the ONNX inference session.
model=get_model_onnx(runtimemodel_s3_filename='runtime_model.onnx')

# Load the preprocessing callable that predict() applies to incoming data.
preprocessor=get_preprocessor(preprocessor_s3_filename="runtime_preprocessor.zip")


def predict(event, model, preprocessor):
    """Run the model on the data carried by ``event`` and return a prediction.

    Args:
        event: Mapping with a ``"body"`` entry. The body is either a JSON
            string or an already-decoded mapping, and must contain a
            ``"data"`` entry holding the raw inputs.
        model: Inference session exposing ``get_inputs()`` and
            ``run(output_names, feed)``.
        preprocessor: Callable that receives the inputs as a
            ``pandas.Series`` and returns an array-like object supporting
            ``astype``.

    Returns:
        The first row of the model's first output, converted to plain
        Python values with ``tolist()``.

    Raises:
        KeyError: If ``"body"`` or ``"data"`` is missing.
        ValueError: If a string body is not valid JSON.
    """
    body = event["body"]
    # The body may arrive as a JSON string or as a decoded mapping;
    # normalise it once so both cases share a single code path.
    if isinstance(body, str):
        body = json.loads(body)

    print(body["data"])  # request payload is echoed to the log
    bodynew = pd.Series(body["data"])

    input_name = model.get_inputs()[0].name

    input_data = preprocessor(bodynew).astype(np.float32)  # needs to be float32

    res = model.run(None, {input_name: input_data})
    prob = res[0]
    # Only the first row of the first output is returned.
    return prob.tolist()[0]

def handler(event, context):
    """Entry point: run a prediction for ``event`` and wrap it in a response.

    Args:
        event: Request mapping; passed unchanged to ``predict``.
        context: Invocation context; unused.

    Returns:
        A dict with ``statusCode`` 200, a fixed set of CORS headers and the
        JSON-encoded prediction as ``body``.
    """
    cors_headers = {
        "Access-Control-Allow-Origin" : "*",
        "Access-Control-Allow-Credentials": True,
        "Allow" : "GET, OPTIONS, POST",
        "Access-Control-Allow-Methods" : "GET, OPTIONS, POST",
        "Access-Control-Allow-Headers" : "*",
    }
    # Uses the model and preprocessor loaded at module import time.
    prediction = predict(event, model, preprocessor)
    response = {"statusCode": 200}
    response["headers"] = cors_headers
    response["body"] = json.dumps(prediction)
    return response
