#Image Regression Prediction Runtime Code

import boto3
import cv2
import os
import numpy as np
import json
import onnxruntime as rt
import base64
import imghdr
import six
from functools import partial
import os.path
from os import path


def get_model_onnx(runtimemodel_s3_filename="runtime_model.onnx"):
    """Download the ONNX model from S3 and wrap it in an inference session.

    Args:
        runtimemodel_s3_filename: File name of the serialized ONNX model,
            looked up under the model-id prefix of the configured bucket.

    Returns:
        An ``onnxruntime.InferenceSession`` built from the downloaded bytes.
    """
    s3 = boto3.resource('s3')
    # Build the key from the argument (previously the file name was
    # hard-coded and the parameter was silently ignored).
    obj = s3.Object("$bucket_name",
                    "$unique_model_id" + "/" + runtimemodel_s3_filename)
    model = rt.InferenceSession(obj.get()['Body'].read())
    return model

def get_preprocessor(preprocessor_s3_filename="runtime_preprocessor.zip"):
    """Download the preprocessor bundle from S3 and load it into this module.

    The zip archive is extracted into ``/tmp/``. Every ``.pkl`` file found
    next to ``preprocessor.py`` is unpickled into a module global named after
    the file (without the ``.pkl`` suffix), and ``preprocessor.py`` is then
    executed in the module's global namespace.

    Args:
        preprocessor_s3_filename: File name of the zip archive, looked up
            under the model-id prefix of the configured bucket.

    Returns:
        The global named ``preprocessor`` after ``preprocessor.py`` has been
        executed (presumably the function that script defines).
    """
    import pickle
    from zipfile import ZipFile
    from io import BytesIO
    import os

    s3 = boto3.resource("s3")
    # Build the key from the argument (previously the file name was
    # hard-coded and the parameter was silently ignored).
    zip_obj = s3.Object(bucket_name="$bucket_name",
                        key="$unique_model_id/" + preprocessor_s3_filename)
    buffer = BytesIO(zip_obj.get()["Body"].read())

    # Extract all the contents of the zip file into /tmp/ and close the
    # archive once done.
    with ZipFile(buffer) as z:
        z.extractall("/tmp/")

    folderpath = os.path.dirname(os.path.abspath("/tmp/preprocessor.py"))

    # Collect every pickle shipped alongside preprocessor.py; these hold the
    # objects the preprocessor script expects to find as globals.
    pickle_file_list = [
        os.path.join(folderpath, file)
        for file in os.listdir(folderpath)
        if file.endswith(".pkl")
    ]

    # NOTE(security): pickle.load and exec run arbitrary code from the
    # downloaded archive; this is only safe if the bucket contents are
    # trusted.
    for pickle_path in pickle_file_list:
        objectname = str(os.path.basename(pickle_path)).replace(".pkl", "")
        with open(str(pickle_path), "rb") as pickle_file:
            globals()[objectname] = pickle.load(pickle_file)

    # Import the preprocessor function into the session from preprocessor.py.
    with open(os.path.join(folderpath, 'preprocessor.py')) as source_file:
        source = source_file.read()
    exec(source, globals())

    # Resolved as a module global: it is bound by the exec call above.
    return preprocessor


def get_runtimedata(runtimedata_s3_filename="runtime_data.json"):
    """Fetch the runtime metadata JSON document from S3 and parse it.

    Args:
        runtimedata_s3_filename: File name of the JSON document, looked up
            under the model-id prefix of the configured bucket.

    Returns:
        The parsed JSON content of the object.
    """
    object_key = "$unique_model_id" + "/" + runtimedata_s3_filename
    response = boto3.resource('s3').Object("$bucket_name", object_key).get()
    # The streaming body is file-like, so json.load can read it directly.
    return json.load(response['Body'])


# Runtime metadata, fetched once when the module is imported.
runtime_data = get_runtimedata("runtime_data.json")
preprocessor_type = runtime_data["runtime_preprocessor"]
runtime_model = runtime_data["runtime_model"]["name"]

# Load the ONNX inference session
model = get_model_onnx("runtime_model.onnx")

# Load preprocessor
preprocessor = get_preprocessor("runtime_preprocessor.zip")


def predict(event,model,preprocessor):
    """Run the model on the base64-encoded image carried by ``event``.

    Args:
        event: Mapping with a "body" entry that is either a JSON string or an
            already-decoded mapping; its "data" entry holds the base64 text
            of the image.
        model: Object exposing ``get_inputs()`` and ``run()`` (the ONNX
            inference session in this file).
        preprocessor: Callable that takes the path of the saved image and
            returns the model input.

    Returns:
        The first element of the first model output, converted with
        ``tolist()``.

    Raises:
        ValueError: If the decoded payload is not recognized as an image.
    """
    # Load base64 encoded image stored within "data" key of event dictionary
    print(event["body"])
    body = event["body"]
    if isinstance(body, six.string_types):
        body = json.loads(body)

    bodydata = body["data"]

    # Decode once; the same bytes are used for type sniffing and for the
    # file written to disk.
    sample = base64.decodebytes(bytearray(bodydata, "utf-8"))

    # Extract image file extension (e.g.-jpg, png, etc.)
    image_file_type = None
    for tf in imghdr.tests:
        image_file_type = tf(sample, None)
        if image_file_type:
            break

    if not image_file_type:
        # Fail with a clear error instead of crashing later on
        # str + None when building the temp file name.
        raise ValueError(
            "This file is not an image, please submit an image base64 encoded image file.")

    # Save image to local file, read into session, and preprocess image with
    # preprocessor function
    image_path = "/tmp/imagetopredict." + image_file_type
    with open(image_path, "wb") as fh:
        fh.write(sample)

    try:
        input_data = preprocessor(image_path)

        # Generate prediction using preprocessed input data
        print("The model expects input shape:", model.get_inputs()[0].shape)
        input_name = model.get_inputs()[0].name

        res = model.run(None, {input_name: input_data})

        # extract predicted value
        result = res[0].tolist()[0]
    finally:
        # Always clean up the temp image, even when preprocessing or
        # inference fails.
        os.remove(image_path)

    return result

def handler(event, context):
    """Lambda entry point: predict for ``event`` and wrap the result.

    Args:
        event: Request payload, forwarded unchanged to ``predict``.
        context: Invocation context; not used.

    Returns:
        A response dict with status code 200, CORS headers and the
        JSON-encoded prediction as the body.
    """
    prediction = predict(event, model, preprocessor)

    cors_headers = {
        "Access-Control-Allow-Origin": "*",
        "Access-Control-Allow-Credentials": True,
        "Allow": "GET, OPTIONS, POST",
        "Access-Control-Allow-Methods": "GET, OPTIONS, POST",
        "Access-Control-Allow-Headers": "*",
    }

    response = {"statusCode": 200}
    response["headers"] = cors_headers
    response["body"] = json.dumps(prediction)
    return response
