#Neural Style Transfer Prediction Runtime Code

import boto3
import cv2
import os
import numpy as np
import json
import onnxruntime as rt
import base64
import imghdr
import six
from functools import partial
import os.path
from os import path
from io import BytesIO
from PIL import Image

def get_model_onnx(runtimemodel_s3_filename="runtime_model.onnx"):
    """Download the ONNX model from S3 and open it as an inference session.

    Args:
        runtimemodel_s3_filename: File name of the model object, appended to
            the model-id key prefix used in the code below.

    Returns:
        The ``onnxruntime.InferenceSession`` created from the downloaded
        model bytes.
    """
    s3 = boto3.resource('s3')
    # Honor the filename argument when building the key, so a caller asking
    # for a non-default model file actually gets that file.
    obj = s3.Object("$bucket_name",
                    "$unique_model_id" + "/" + runtimemodel_s3_filename)
    # InferenceSession accepts the serialized model as raw bytes, so the
    # model never has to be written to disk.
    model = rt.InferenceSession(obj.get()['Body'].read())
    return model

def get_preprocessor(preprocessor_s3_filename="runtime_preprocessor.zip"):
    """Download the preprocessor bundle from S3 and load it into this module.

    The zip archive is extracted into /tmp/. Every ``.pkl`` file found there
    is unpickled into a module-level global named after the file (without the
    ``.pkl`` part), and then ``preprocessor.py`` is executed in the module's
    global namespace.

    SECURITY NOTE: both ``pickle.load`` and ``exec`` run arbitrary code taken
    from the downloaded archive. This is only acceptable while the bucket
    contents are fully trusted.

    Args:
        preprocessor_s3_filename: File name of the zip object, appended to
            the model-id key prefix used in the code below.

    Returns:
        The global named ``preprocessor`` after ``preprocessor.py`` has been
        executed.

    Raises:
        NameError: If executing ``preprocessor.py`` did not leave a global
            named ``preprocessor``.
    """
    import pickle
    from zipfile import ZipFile
    from io import BytesIO
    import os

    s3 = boto3.resource("s3")
    # Honor the filename argument when building the object key.
    zip_obj = s3.Object(bucket_name="$bucket_name",
                        key="$unique_model_id/" + preprocessor_s3_filename)
    buffer = BytesIO(zip_obj.get()["Body"].read())

    # Extract all the contents of the zip file into /tmp/; the context
    # manager closes the archive once extraction is done.
    with ZipFile(buffer) as z:
        z.extractall("/tmp/")

    folderpath = os.path.dirname(os.path.abspath("/tmp/preprocessor.py"))

    # Load every pickle in the extraction folder as a module global so that
    # the code in preprocessor.py can refer to those objects by name.
    pickle_file_list = [
        os.path.join(folderpath, file)
        for file in os.listdir(folderpath)
        if file.endswith(".pkl")
    ]
    for pickle_path in pickle_file_list:
        objectname = str(os.path.basename(pickle_path)).replace(".pkl", "")
        with open(str(pickle_path), "rb") as pickle_file:
            globals()[objectname] = pickle.load(pickle_file)

    # Import the preprocessor function into this module from preprocessor.py.
    with open(os.path.join(folderpath, 'preprocessor.py')) as source_file:
        exec(source_file.read(), globals())

    # The function only exists as a side effect of the exec above, so look
    # it up explicitly and fail with a clear message when it is missing.
    try:
        return globals()["preprocessor"]
    except KeyError:
        raise NameError(
            "preprocessor.py did not define a 'preprocessor' object"
        ) from None


def get_runtimedata(runtimedata_s3_filename="runtime_data.json"):
    """Fetch a JSON document from S3 and return it parsed.

    Args:
        runtimedata_s3_filename: File name of the JSON object, appended to
            the model-id key prefix used in the code below.

    Returns:
        The object produced by ``json.load`` on the S3 response body.
    """
    s3_resource = boto3.resource('s3')
    object_key = "/".join(("$unique_model_id", runtimedata_s3_filename))
    response = s3_resource.Object("$bucket_name", object_key).get()
    # The response body is a file-like stream, so json.load reads it directly.
    return json.load(response['Body'])



# Module-level initialisation: everything below runs once at import time and
# leaves the globals that handler() passes on to predict().

# Read the runtime metadata and pick out the two settings stored in it.
runtime_data = get_runtimedata("runtime_data.json")
preprocessor_type = runtime_data["runtime_preprocessor"]
runtime_model = runtime_data["runtime_model"]["name"]

# Build the ONNX inference session.
model = get_model_onnx("runtime_model.onnx")

# Load preprocessor
preprocessor = get_preprocessor("runtime_preprocessor.zip")



def _decode_image(encoded, label):
    """Decode a base64 string and detect which image format it holds.

    Args:
        encoded: Base64 text of the image file.
        label: Short name of the input ("content" or "style"), used only in
            the error message.

    Returns:
        A ``(raw_bytes, file_type)`` tuple where ``file_type`` is the format
        name reported by ``imghdr`` (e.g. "jpeg", "png").

    Raises:
        ValueError: If the decoded bytes are not recognised as an image
            (invalid base64 also surfaces as a ValueError subclass).
    """
    raw = base64.decodebytes(bytearray(encoded, "utf-8"))
    # imghdr.what with h= runs every registered format test on the bytes.
    # NOTE: imghdr is deprecated since Python 3.11 and removed in 3.13.
    file_type = imghdr.what(None, h=raw)
    if file_type is None:
        # Fail here with a clear message instead of crashing later on a
        # str + None concatenation when the temp file name is built.
        raise ValueError(
            "This file is not an image, please submit an image base64 "
            "encoded image file. (input: " + label + ")"
        )
    return raw, file_type


def predict(event,model,preprocessor):
    """Run style transfer on the images carried by an event.

    Args:
        event: Mapping whose "body" entry is either a dict or a JSON string
            holding base64 encoded images under "content" and "style".
        model: Inference session; it is run with the inputs "placeholder"
            (content image) and "placeholder_1" (style image).
        preprocessor: Callable that takes an image file path and returns the
            array fed to the model.

    Returns:
        The first model output, rendered as a JPEG and base64 encoded, as a
        str.

    Raises:
        KeyError: If "body", "content" or "style" is missing.
        ValueError: If either payload is not a recognisable image.
    """
    body = event["body"]
    if isinstance(body, str):
        body = json.loads(body)

    content = body["content"]
    style = body["style"]

    # Decode each payload once and reuse the bytes for both the format
    # detection and the temp file written below.
    content_bytes, content_file_type = _decode_image(content, "content")
    style_bytes, style_file_type = _decode_image(style, "style")

    # Save each image to a local file and preprocess it from that path. Both
    # images may share one path when their formats match, so the content
    # image is fully preprocessed before the style image overwrites it.
    content_path = "/tmp/imagetopredict." + content_file_type
    with open(content_path, "wb") as fh:
        fh.write(content_bytes)
    content_image = preprocessor(content_path)

    style_path = "/tmp/imagetopredict." + style_file_type
    with open(style_path, "wb") as fh:
        fh.write(style_bytes)
    style_image = preprocessor(style_path)

    # Generate prediction using preprocessed input data
    print("The model expects input shape:", model.get_inputs()[0].shape)
    res = model.run(None, {
        "placeholder": content_image,
        "placeholder_1": style_image
    })

    # Drop the leading batch axis and scale to 8-bit. Clipping first keeps
    # out-of-range values from wrapping around in the uint8 cast.
    stylized = res[0].squeeze(axis=0)
    pixels = np.clip(stylized * 255, 0, 255).astype(np.uint8)

    pil_img = Image.fromarray(pixels)
    buffered = BytesIO()
    pil_img.save(buffered, format="JPEG")
    img_str = base64.b64encode(buffered.getvalue())
    return img_str.decode("utf-8")

def handler(event, context):
    """Run a prediction for the event and wrap it in an HTTP-style response.

    Presumably the AWS Lambda entry point (confirm against the deployment
    configuration). ``context`` is accepted but not used.

    Args:
        event: Request event, forwarded unchanged to ``predict``.
        context: Unused.

    Returns:
        A dict with "statusCode" 200, permissive CORS "headers", and a
        "body" holding the JSON-encoded prediction result.
    """
    cors_headers = {
        "Access-Control-Allow-Origin" : "*",
        "Access-Control-Allow-Credentials": True,
        "Allow" : "GET, OPTIONS, POST",
        "Access-Control-Allow-Methods" : "GET, OPTIONS, POST",
        "Access-Control-Allow-Headers" : "*"
    }
    # model and preprocessor are the module-level objects loaded at import.
    prediction = predict(event, model, preprocessor)
    response = {
        "statusCode": 200,
        "headers": cors_headers,
        "body" : json.dumps(prediction),
    }
    return response