import torch
from torch.autograd import Variable

from . import ocr_utils as utils
from . import alphabets

import numpy as np
import cv2

from tritonclient.utils import *

import tritonclient.grpc as grpcclient
import tritonclient.utils.shared_memory as shm


# Triton gRPC endpoint (host:port) used by TrtCtcOcrSHM's client.
# NOTE: this was formerly imported via `from .ipconfig import OCR_ip`.
OCR_ip = "127.0.0.1:8001"

# Default Triton model name for TrtCtcOcrSHM.
model_name = "resnet32_ctc_plate_ocr"

# Model version string kept as text (TrtCtcOcrSHM keeps its own copy).
Version_id = "1"


class TrtCtcOcrSHM(object):
    """CTC plate-OCR client that queries a Triton server over gRPC.

    Input and output tensors are exchanged through two system shared-memory
    regions that are created and registered with the server in ``__init__``.
    Call :meth:`close` (or use the instance as a context manager) to
    unregister and destroy them when finished.

    Args:
        model: Name of the Triton model to query.
        debug: Passed to the gRPC client as ``verbose``.
        do_unregister: If true, unregister *all* system and CUDA shared
            memory regions currently registered on the server before
            registering this client's regions.
    """

    # Fixed region names / shared-memory keys.
    # NOTE(review): because they are fixed, a second instance against the
    # same server will presumably fail at registration (name already in
    # use) unless do_unregister is set.
    _OUTPUT_REGION = "ctcocr_output_data"
    _OUTPUT_KEY = "/ctcocr_output_simple"
    _INPUT_REGION = "ctcocr_input_data"
    _INPUT_KEY = "/ctcocr_input_simple"

    def __init__(self, model=model_name, debug=False, do_unregister=False):
        self.model_name = model
        self.Version_id = str(1)
        # cv2.resize takes dsize as (width, height): 480 wide, 16 high.
        self.size = (480, 16)
        alphabet = """ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-"""
        self.converter = utils.strLabelConverter(alphabet)

        self.client = grpcclient.InferenceServerClient(OCR_ip, verbose=debug)

        # Register shared memory.
        # Optionally make sure no shared memory regions are registered with
        # the server before registering ours.
        if do_unregister:
            self.client.unregister_system_shared_memory()
            self.client.unregister_cuda_shared_memory()

        # float32 (4 bytes) input laid out as (1, 1, 16, 480); see recognition().
        self.input_byte_size = 1 * self.size[0] * self.size[1] * 4
        # float32 output of 30 * 38 elements. Assumes the model emits 30 time
        # steps x 38 classes (37 alphabet symbols + CTC blank) — TODO confirm
        # against the model config.
        self.output_byte_size = 1 * 30 * 38 * 4

        # Create the output region and register it with the server.
        self.shm_op_handle = shm.create_shared_memory_region(
            self._OUTPUT_REGION, self._OUTPUT_KEY, self.output_byte_size)
        self.client.register_system_shared_memory(
            self._OUTPUT_REGION, self._OUTPUT_KEY, self.output_byte_size)

        # Create the input region and register it with the server.
        self.shm_ip_handle = shm.create_shared_memory_region(
            self._INPUT_REGION, self._INPUT_KEY, self.input_byte_size)
        self.client.register_system_shared_memory(
            self._INPUT_REGION, self._INPUT_KEY, self.input_byte_size)

        # Only mark the instance as open once everything is registered.
        self._closed = False

    def preprocess(self, img_raw):
        """Resize, grayscale and normalise an image for the model input.

        Args:
            img_raw: Image array accepted by ``cv2.resize``. A multi-channel
                image is converted with ``COLOR_RGB2GRAY``; a 2-D (already
                grayscale) image is used as-is.

        Returns:
            float32 array of shape (1, 16, 480); for uint8 input the values
            lie in [-1, 1].
        """
        img = cv2.resize(img_raw, self.size, interpolation=cv2.INTER_LINEAR)
        # cvtColor rejects single-channel input, so only convert colour images.
        if img.ndim == 3:
            img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        img = np.float32(img) / 255
        img = np.subtract(img, 0.5) / 0.5
        return np.expand_dims(img, 0)

    def postprocess(self, infer_results):
        """Greedy-decode raw model output into text.

        Takes the argmax over axis 2 of the (rank-3) output and hands the
        flattened index sequence to the label converter with ``raw=False``.

        Args:
            infer_results: numpy array read from the output region; assumed
                to be (1, T, C) or (T, 1, C) — TODO confirm axis order.

        Returns:
            Whatever ``strLabelConverter.decode`` returns (the decoded text).
        """
        preds = torch.from_numpy(infer_results).transpose(0, 1)
        _, preds = preds.max(2)
        preds = preds.transpose(1, 0).contiguous().view(-1)
        preds_size = torch.LongTensor([preds.size(0)])
        return self.converter.decode(preds, preds_size, raw=False)

    def recognition(self, image):
        """Run OCR on one image and return the decoded text.

        Returns:
            The result of :meth:`postprocess`, or ``None`` when the response
            carries no ``output0``.
        """
        input0_data = np.expand_dims(self.preprocess(image), 0)

        # Copy the input tensor into the registered input region.
        shm.set_shared_memory_region(self.shm_ip_handle, [input0_data])

        inputs = [
            grpcclient.InferInput("input0",
                                  input0_data.shape,
                                  np_to_triton_dtype(input0_data.dtype)),
        ]
        inputs[0].set_shared_memory(self._INPUT_REGION, self.input_byte_size)

        outputs = [grpcclient.InferRequestedOutput("output0")]
        # The requested size must not exceed the size the region was
        # registered with, otherwise the server rejects (or overruns) it.
        outputs[0].set_shared_memory(self._OUTPUT_REGION, self.output_byte_size)

        # NOTE(review): Version_id is sent as request_id; model_version is
        # left at its default, so the server's version policy picks the
        # version. Confirm whether `model_version=self.Version_id` was meant.
        response = self.client.infer(self.model_name,
                                     inputs,
                                     request_id=self.Version_id,
                                     outputs=outputs)

        output_results = response.get_output("output0")
        if output_results is None:
            return None

        output_data = shm.get_contents_as_numpy(
            self.shm_op_handle,
            triton_to_np_dtype(output_results.datatype),
            output_results.shape)
        return self.postprocess(output_data)

    def close(self):
        """Unregister and destroy this client's shared-memory regions.

        Also closes the gRPC client. Calling it more than once is a no-op.
        Local regions are destroyed even if server-side unregistration fails.
        """
        if getattr(self, "_closed", True):
            return
        self._closed = True
        try:
            self.client.unregister_system_shared_memory(self._INPUT_REGION)
            self.client.unregister_system_shared_memory(self._OUTPUT_REGION)
        finally:
            shm.destroy_shared_memory_region(self.shm_ip_handle)
            shm.destroy_shared_memory_region(self.shm_op_handle)
            self.client.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Always release resources; never suppress the exception.
        self.close()
