#Tabular Classification Prediction Runtime Code

import boto3
import pandas as pd
import os
import numpy as np
import onnxruntime as rt
import json

def get_model_onnx(runtimemodel_s3_filename="runtime_model.onnx"):
    """Download the ONNX model from S3 and wrap it in an inference session.

    Args:
        runtimemodel_s3_filename: File name of the model object, looked up
            under the model-id key prefix of the configured bucket.

    Returns:
        An ``onnxruntime.InferenceSession`` built from the object's raw bytes.
    """
    s3 = boto3.resource('s3')
    # Build the key from the argument; the file name used to be hard-coded,
    # which silently ignored whatever the caller passed in.
    obj = s3.Object("$bucket_name",
                    "$unique_model_id" + "/" + runtimemodel_s3_filename)
    model = rt.InferenceSession(obj.get()['Body'].read())
    return model
    
def get_preprocessor(preprocessor_s3_filename="runtime_preprocessor.zip"):
    """Download the preprocessor archive from S3 and load its ``preprocessor``.

    The archive is unpacked into /tmp/. Every ``*.pkl`` file found in that
    directory is unpickled into a module-level global named after the file
    (with ".pkl" removed), and then ``preprocessor.py`` is executed in the
    module globals so the code it defines can reference those objects.

    Args:
        preprocessor_s3_filename: File name of the zip archive, looked up
            under the model-id key prefix of the configured bucket.

    Returns:
        The global ``preprocessor`` object that exists after executing
        ``preprocessor.py``.
    """
    import pickle
    from zipfile import ZipFile
    from io import BytesIO

    s3 = boto3.resource("s3")
    # Build the key from the argument; the file name used to be hard-coded.
    zip_obj = s3.Object(bucket_name="$bucket_name",
                        key="$unique_model_id/" + preprocessor_s3_filename)
    buffer = BytesIO(zip_obj.get()["Body"].read())
    # Extract all the contents of the zip file into /tmp/.
    with ZipFile(buffer) as archive:
        archive.extractall("/tmp/")

    folderpath = os.path.dirname(os.path.abspath("/tmp/preprocessor.py"))

    # SECURITY: pickle.load and exec below run arbitrary code taken from the
    # archive. This is only acceptable while the bucket contents are trusted.

    # Load every pickled object next to preprocessor.py into module globals,
    # keyed by file name, so the preprocessor code can look them up by name.
    for file in os.listdir(folderpath):
        if file.endswith(".pkl"):
            objectname = file.replace(".pkl", "")
            with open(os.path.join(folderpath, file), "rb") as pickle_file:
                globals()[objectname] = pickle.load(pickle_file)

    # Execute preprocessor.py in the module namespace; the ``preprocessor``
    # name returned below is resolved from the globals after this runs.
    with open(os.path.join(folderpath, 'preprocessor.py')) as source_file:
        exec(source_file.read(), globals())
    return preprocessor

def get_runtimedata(runtimedata_s3_filename="runtime_data.json"):
    """Fetch a JSON document from S3 and return it parsed.

    Args:
        runtimedata_s3_filename: File name of the JSON object, looked up
            under the model-id key prefix of the configured bucket.

    Returns:
        The decoded JSON content of the object.
    """
    object_key = "$unique_model_id" + "/" + runtimedata_s3_filename
    response = boto3.resource('s3').Object("$bucket_name", object_key).get()
    # Read the whole body and decode it in one step.
    return json.loads(response['Body'].read())


# Module-level initialisation: these statements run once when the module is
# loaded, so the S3 downloads below happen before any call to handler().

# Fetch and parse the runtime metadata JSON.
runtime_data=get_runtimedata(runtimedata_s3_filename="runtime_data.json")

# Preprocessor entry from the metadata; not referenced again in the visible code.
preprocessor_type=runtime_data["runtime_preprocessor"]

# Model name from the metadata; not referenced again in the visible code.
runtime_model=runtime_data["runtime_model"]["name"]

# Load model
model=get_model_onnx(runtimemodel_s3_filename='runtime_model.onnx')

# Load preprocessor
# (get_preprocessor also populates module globals from the archive's pickle files.)
preprocessor=get_preprocessor(preprocessor_s3_filename="runtime_preprocessor.zip")


def predict(event,model,preprocessor):
    body = event["body"]
    labels=$labels
    import six
    if isinstance(event["body"], six.string_types):
        body = json.loads(event["body"])
        print(body["data"])
        bodydata = pd.DataFrame.from_dict(body["data"])
    else:
        print(body["data"])
        bodydata = pd.DataFrame.from_dict(body["data"])
        print(bodydata)

    def predict_classes(x): # adjusted from keras github code
        if len(x.shape)==2:
            index = x.argmax(axis=-1)
            return list(map(lambda x: labels[x], index))
        else:
            return list(x)
    input_name = model.get_inputs()[0].name
    print(input_name)
    input_data = preprocessor(bodydata).astype('float32') #needs to be float32
    print(input_data)
    res=model.run(None,  {input_name: input_data})
    prob = res[0]
    print(prob)
    return predict_classes(prob)

def handler(event, context):
    """Entry point: run a prediction for ``event`` and wrap it in a response.

    Args:
        event: Request mapping, passed unchanged to ``predict``.
        context: Unused.

    Returns:
        A dict with a 200 status code, permissive CORS headers and the
        prediction result serialised as JSON in "body".
    """
    cors_headers = {
        "Access-Control-Allow-Origin" : "*",
        "Access-Control-Allow-Credentials": True,
        "Allow" : "GET, OPTIONS, POST",
        "Access-Control-Allow-Methods" : "GET, OPTIONS, POST",
        "Access-Control-Allow-Headers" : "*"
    }
    predictions = predict(event, model, preprocessor)
    response = {"statusCode": 200}
    response["headers"] = cors_headers
    response["body"] = json.dumps(predictions)
    return response
