import torch
from torch.autograd import Variable

from . import ocr_utils as utils
from . import alphabets

import numpy as np
import cv2

from tritonclient.utils import *
import tritonclient.grpc as grpclient
#from .ipconfig import OCR_ip
# Address of the Triton server's gRPC endpoint; handed to InferenceServerClient.
OCR_ip = "127.0.0.1:8001"
# Name of the model that infer_triton_trt asks the server to run.
model_name = "resnet32_ctc_plate_ocr"
# Sent as the ``request_id`` of every inference call.
# NOTE(review): the name suggests a model version -- confirm the intent.
Version_id = "1"
def infer_triton_trt(image):
    """Run one image through the Triton-hosted OCR model over gRPC.

    A leading axis of size 1 is added to ``image`` and the result is sent as
    the tensor named ``"input0"``. The tensor named ``"output0"`` is requested
    and returned.

    Args:
        image: NumPy array holding a single preprocessed image. Its dtype is
            forwarded to Triton via ``np_to_triton_dtype``.

    Returns:
        The ``"output0"`` tensor converted to a NumPy array.
    """
    # The same name must be used to request the output and to read it back:
    # ``as_numpy`` yields None for a name that is not in the response.
    # NOTE(review): "output0" is chosen to match "input0" -- confirm against
    # the model's config.
    output_name = "output0"
    batch = np.expand_dims(image, 0)

    # A fresh client per call; the context manager closes the channel on exit.
    with grpclient.InferenceServerClient(OCR_ip) as client:
        infer_input = grpclient.InferInput(
            "input0", batch.shape, np_to_triton_dtype(batch.dtype)
        )
        infer_input.set_data_from_numpy(batch)

        # NOTE(review): Version_id is sent as the request id, not as
        # ``model_version`` -- confirm this is intended.
        response = client.infer(
            model_name,
            [infer_input],
            request_id=Version_id,
            outputs=[grpclient.InferRequestedOutput(output_name)],
        )
        return response.as_numpy(output_name)

class TrtCtcOcr(object):
    """OCR pipeline: preprocess an image, run remote inference, decode text.

    ``recognition`` chains ``preprocess`` -> ``infer_triton_trt`` ->
    ``postprocess``.
    """

    def __init__(self, input_shape=(480, 16)):
        """Store the resize target and build the label converter.

        Args:
            input_shape: Size handed to ``cv2.resize`` as ``dsize``, which
                OpenCV interprets as ``(width, height)``.
        """
        self.input_shape = input_shape
        # Used by postprocess to turn an index sequence into the final result.
        self.converter = utils.strLabelConverter(alphabets.alphabet)

    def preprocess(self, img_raw):
        """Resize, convert to grayscale and normalise an image.

        Args:
            img_raw: Colour image array.
                NOTE(review): converted with COLOR_RGB2GRAY, so RGB channel
                order is assumed -- confirm callers do not pass BGR.

        Returns:
            float32 array with a new leading axis of size 1, computed as
            ``(pixel / 255 - 0.5) / 0.5`` (i.e. [-1, 1] for 8-bit input).
        """
        resized = cv2.resize(
            img_raw, self.input_shape, interpolation=cv2.INTER_LINEAR
        )
        gray = cv2.cvtColor(resized, cv2.COLOR_RGB2GRAY)
        scaled = np.float32(gray) / 255
        normalised = np.subtract(scaled, 0.5) / 0.5
        return np.expand_dims(normalised, 0)

    def postprocess(self, infer_results):
        """Pick the best-scoring index at each position and decode it.

        Args:
            infer_results: NumPy array with exactly 30 * 1 * 38 elements.
                NOTE(review): the 30 x 1 x 38 layout is hard-coded;
                presumably (sequence steps, batch, classes) -- confirm
                against the model.

        Returns:
            Whatever ``self.converter.decode(..., raw=False)`` returns.
        """
        scores = torch.from_numpy(infer_results).reshape(30, 1, 38)
        # Index of the maximum along the last axis -> shape (30, 1).
        _, best = scores.max(2)
        # Flatten to a 1-D sequence of 30 indices.
        best = best.transpose(1, 0).contiguous().view(-1)
        # Plain tensors suffice: the Variable wrapper is deprecated and is a
        # no-op on current PyTorch.
        length = torch.LongTensor([best.size(0)])
        return self.converter.decode(best, length, raw=False)

    def recognition(self, img):
        """Run the full pipeline on ``img`` and return the decoded result."""
        model_input = self.preprocess(img)
        raw_output = infer_triton_trt(model_input)
        return self.postprocess(raw_output)
