import pycuda.autoinit
import numpy as np
import pycuda.driver as cuda
import tensorrt as trt
import os


# Module-wide TensorRT logger; it is handed to the trt.Builder, trt.OnnxParser
# and trt.Runtime objects created in get_engine().
TRT_LOGGER = trt.Logger()

def get_img_np_nchw(filename):
    """Load an image file and preprocess it into an NCHW float32 batch of one.

    The image is converted to grayscale and back to three channels (so all
    channels carry the same gray value), resized to 640x640, shifted by a
    fixed per-channel offset and reordered from HWC to CHW.

    Args:
        filename: Path of the image file to read with OpenCV.

    Returns:
        A float32 numpy array of shape (1, 3, 640, 640).

    Raises:
        OSError: If OpenCV cannot read ``filename``.
    """
    # cv2 is not imported at file level, so bring it into scope here;
    # without this every call fails with NameError.
    import cv2

    image = cv2.imread(filename)
    if image is None:
        # cv2.imread signals a missing/unreadable file by returning None
        # instead of raising; fail here with a clear message.
        raise OSError('Could not read image file {}'.format(filename))

    image = np.float32(image)
    image_cv = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    image_cv = cv2.cvtColor(image_cv, cv2.COLOR_GRAY2BGR)
    image_cv = cv2.resize(image_cv, (640, 640))
    # In-place subtraction keeps the array float32.
    # NOTE(review): presumably the per-channel mean used at training time -- confirm.
    image_cv -= (105, 110, 110)
    image_cv = image_cv.transpose(2, 0, 1)  # HWC -> CHW
    img_np_nchw = np.expand_dims(image_cv, axis=0)  # add the batch axis
    return img_np_nchw

class HostDeviceMem:
    """Pair a host buffer with the device buffer that mirrors it."""

    def __init__(self, host_mem, device_mem):
        """Store both buffers.

        In this context "host" means CPU memory and "device" means GPU memory.
        """
        self.host, self.device = host_mem, device_mem

    def __str__(self):
        # Render both buffers, each under its own heading.
        pieces = ("Host:\n", str(self.host), "\nDevice:\n", str(self.device))
        return "".join(pieces)

    def __repr__(self):
        # The debugging representation is the same text as str().
        return str(self)

def allocate_buffers(engine):
    """Allocate one host/device buffer pair for every binding of ``engine``.

    Args:
        engine: Engine object that is iterable over its bindings and exposes
            ``get_binding_shape``, ``get_binding_dtype``, ``binding_is_input``
            and ``max_batch_size``.

    Returns:
        A tuple ``(inputs, outputs, bindings, stream)``: two lists of
        HostDeviceMem pairs, the list of device buffer addresses (as ints,
        in binding iteration order) and a newly created CUDA stream.
    """
    stream = cuda.Stream()
    inputs, outputs, bindings = [], [], []

    for entry in engine:
        # Element count for the binding, scaled by the engine's max batch size.
        element_count = trt.volume(engine.get_binding_shape(entry)) * engine.max_batch_size
        np_dtype = trt.nptype(engine.get_binding_dtype(entry))

        # Page-locked host buffer plus a device allocation of the same byte size.
        host_buf = cuda.pagelocked_empty(element_count, np_dtype)
        dev_buf = cuda.mem_alloc(host_buf.nbytes)
        bindings.append(int(dev_buf))

        # File the pair under inputs or outputs depending on the binding kind.
        target = inputs if engine.binding_is_input(entry) else outputs
        target.append(HostDeviceMem(host_buf, dev_buf))

    return inputs, outputs, bindings, stream


def get_engine(max_batch_size=1, onnx_file_path="", engine_file_path="",
               fp16_mode=False, int8_mode=False, save_engine=False, verbose=False):
    """Load a serialized TensorRT engine if one exists, otherwise build it from ONNX.

    Args:
        max_batch_size: Maximum batch size configured on the builder.
        onnx_file_path: ONNX model to parse when an engine has to be built.
        engine_file_path: Serialized engine file. If it exists it is loaded and
            the build options are ignored; when building with ``save_engine``
            the new engine is written here.
        fp16_mode: Value assigned to ``builder.fp16_mode`` when building.
        int8_mode: Not supported yet; a build with it set raises.
        save_engine: Serialize a freshly built engine to ``engine_file_path``.
        verbose: Print a message when an existing engine file is loaded.

    Returns:
        The engine, or ``None`` when the ONNX file could not be parsed or the
        builder failed to produce an engine.

    Raises:
        NotImplementedError: If a build is required and ``int8_mode`` is set.
        SystemExit: If a build is required and ``onnx_file_path`` does not exist.
    """

    def build_engine(max_batch_size, save_engine):
        """Parse the ONNX file and build a TensorRT engine from it."""
        # Check the cheap preconditions before any TensorRT object is created.
        if int8_mode:
            # To be updated
            raise NotImplementedError('int8_mode is not supported yet')

        if not os.path.exists(onnx_file_path):
            # Same exception quit() raised, without relying on the `site`
            # module having injected the `quit` helper.
            raise SystemExit('ONNX file {} not found'.format(onnx_file_path))

        with trt.Builder(TRT_LOGGER) as builder, \
                builder.create_network() as network, \
                trt.OnnxParser(network, TRT_LOGGER) as parser:

            builder.max_workspace_size = 1 << 30  # 1 GiB of builder workspace
            builder.max_batch_size = max_batch_size
            builder.fp16_mode = fp16_mode  # Default: False
            builder.int8_mode = int8_mode  # Always False here (see check above)

            print('Loading ONNX file from path {}...'.format(onnx_file_path))
            with open(onnx_file_path, 'rb') as model:
                print('Beginning ONNX file parsing')
                if not parser.parse(model.read()):
                    # Report the parser's errors instead of silently building
                    # an engine from an incomplete network.
                    print('ERROR: Failed to parse the ONNX file {}'.format(onnx_file_path))
                    for index in range(parser.num_errors):
                        print(parser.get_error(index))
                    return None

            print('Completed parsing of ONNX file')
            print('Building an engine from file {}; this may take a while...'.format(onnx_file_path))

            engine = builder.build_cuda_engine(network)
            if engine is None:
                # A failed build yields None; do not try to serialize it.
                print('ERROR: Failed to build an engine from {}'.format(onnx_file_path))
                return None
            print("Completed creating Engine")

            if save_engine:
                with open(engine_file_path, "wb") as f:
                    f.write(engine.serialize())
            return engine

    if os.path.exists(engine_file_path):
        # If a serialized engine exists, load it instead of building a new one.
        if verbose:
            print("Reading engine from file {}".format(engine_file_path))
        with open(engine_file_path, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
            return runtime.deserialize_cuda_engine(f.read())

    return build_engine(max_batch_size, save_engine)
    
def do_inference(context, bindings, inputs, outputs, stream, batch_size=1):
    """Run one asynchronous inference pass and return the host output buffers.

    Args:
        context: Execution context providing ``execute_async``.
        bindings: Device buffer addresses passed to ``execute_async``.
        inputs: Objects with ``host`` / ``device`` attributes; each ``host``
            buffer is copied to its ``device`` buffer before execution.
        outputs: Objects with ``host`` / ``device`` attributes; each ``device``
            buffer is copied back to its ``host`` buffer after execution.
        stream: CUDA stream on which the copies and the execution are queued.
        batch_size: Batch size passed to ``execute_async``.

    Returns:
        A list with the ``host`` buffer of every output, in order.

    Raises:
        RuntimeError: If ``execute_async`` reports that the work could not be
            enqueued.
    """
    # Transfer data from CPU to the GPU.
    for inp in inputs:
        cuda.memcpy_htod_async(inp.device, inp.host, stream)

    # Run inference; execute_async reports enqueue failure through its
    # boolean result, which must not be ignored or stale buffers get returned.
    enqueued = context.execute_async(
        batch_size=batch_size, bindings=bindings, stream_handle=stream.handle)
    if not enqueued:
        raise RuntimeError('TensorRT failed to enqueue the inference (execute_async returned False)')

    # Transfer predictions back from the GPU.
    for out in outputs:
        cuda.memcpy_dtoh_async(out.host, out.device, stream)

    # Wait until everything queued on the stream has finished.
    stream.synchronize()

    # Return only the host outputs.
    return [out.host for out in outputs]
    
def postprocess_the_outputs(h_outputs, shape_of_output):
    """Return ``h_outputs`` reshaped to the dimensions in ``shape_of_output``.

    Args:
        h_outputs: Array-like object exposing ``reshape``.
        shape_of_output: Iterable of target dimensions; it is unpacked into
            separate positional arguments of ``reshape``.
    """
    return h_outputs.reshape(*shape_of_output)


class trt_engine():
    """Wrap a TensorRT engine together with its buffers and execution context.

    Typical use: construct once, then for every input call
    ``load_to_buffers(data)`` followed by ``go_inference()``.
    """

    def __init__(self,fp16=True,int8=False,trt_engine_path = './alpr/Models/test_int8.trt',onnx_model_path = './alpr/Models/test.onnx'):
        """Load (or build) the engine and allocate everything needed to run it.

        Args:
            fp16: Forwarded to ``get_engine`` as ``fp16_mode``.
            int8: Forwarded to ``get_engine`` as ``int8_mode``.
            trt_engine_path: Serialized engine file; loaded if it exists.
            onnx_model_path: ONNX model used when the engine must be built.
        """
        self.max_batch_size = 1
        self.trt_engine_path = trt_engine_path
        self.onnx_model_path = onnx_model_path
        self.fp16 = fp16
        self.int8 = int8

        # Build an engine (or load the serialized one).
        self.engine = get_engine(self.max_batch_size, self.onnx_model_path, self.trt_engine_path, self.fp16, self.int8)
        # Allocate host/device buffers for every input and output binding.
        self.inputs, self.outputs, self.bindings, self.stream = allocate_buffers(self.engine)

        # Create the context for this engine.
        self.context = self.engine.create_execution_context()

    def load_to_buffers(self, input_data):
        """Copy ``input_data`` into the host buffer of the first input.

        The data is copied into the existing buffer rather than replacing it,
        so the buffer allocated by ``allocate_buffers`` (and its dtype and
        size, which match the device allocation) is preserved.

        Args:
            input_data: Array-like whose element count equals the size of the
                first input's host buffer.

        Raises:
            ValueError: If the element count does not match the buffer size.
        """
        host_buf = self.inputs[0].host
        flat = np.asarray(input_data).reshape(-1)
        if flat.size != host_buf.size:
            # Checked explicitly: np.copyto would otherwise broadcast a
            # single-element input across the whole buffer.
            raise ValueError(
                'input has {} elements but the input buffer holds {}'.format(flat.size, host_buf.size))
        np.copyto(host_buf, flat.reshape(host_buf.shape))

    def go_inference(self):
        """Run inference on the loaded input and return ``(loc, conf, landms)``.

        The three arrays are reshaped from the host output buffers to
        ``(max_batch_size, 16800, k)`` with k = 4, 2 and 8 respectively.
        """
        trt_outputs = do_inference(self.context, bindings=self.bindings, inputs=self.inputs, outputs=self.outputs, stream=self.stream)  # numpy data

        # NOTE(review): assumes the engine's outputs are ordered loc, landms,
        # conf and that the model yields 16800 rows per image -- confirm
        # against the exported model.
        # NOTE(review): the results come from reshaping the host output
        # buffers, so they may be overwritten by the next inference call;
        # copy them if they must persist.
        loc = postprocess_the_outputs(trt_outputs[0], (self.max_batch_size, 16800, 4))
        landms = postprocess_the_outputs(trt_outputs[1], (self.max_batch_size, 16800, 8))
        conf = postprocess_the_outputs(trt_outputs[2], (self.max_batch_size, 16800, 2))
        return loc, conf, landms
        

