import torch
from torch.autograd import Variable

from . import ocr_utils as utils
from . import alphabets

import numpy as np
import cv2

from tritonclient.utils import *
import tritonclient.grpc as grpclient
# Triton gRPC endpoint ("host:port"). Alternatively it can be loaded from a
# config module via `from .ipconfig import OCR_ip`.
OCR_ip = "127.0.0.1:8001"
# Name of the CTC plate-recognition model deployed on the Triton server.
model_name = "resnet32_ctc_plate_ocr"
# Model version, kept as a string.
Version_id = str(1)


class TrtCtcOcr(object):
    """Licence-plate OCR client for a CTC recognition model served by Triton.

    An image is preprocessed locally (resize, grayscale, normalise), sent to a
    Triton Inference Server over gRPC, and the returned scores are greedily
    decoded to text with ``utils.strLabelConverter``.
    """

    def __init__(self, model=model_name, url=None):
        """Create a recogniser.

        Args:
            model: Name of the model deployed on the Triton server.
            url: ``host:port`` of the Triton gRPC endpoint. ``None`` (the
                default) uses the module-level ``OCR_ip``, looked up on each
                call to ``recognition``.
        """
        self.model_name = model
        self.url = url
        # Model version requested from Triton (sent as ``model_version``).
        self.Version_id = str(1)
        # Target (width, height) passed to cv2.resize.
        self.size = (480, 16)
        alphabet = """ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-"""
        self.converter = utils.strLabelConverter(alphabet)

    def preprocess(self, img_raw):
        """Resize, convert to grayscale and normalise an image.

        Args:
            img_raw: Image as a NumPy array. 3-channel input is treated as
                RGB, 4-channel as RGBA; 2-D or single-channel input is
                already grayscale and is used as is.

        Returns:
            float32 array of shape (1, H, W), where (W, H) == ``self.size``.
            Pixel values are mapped by ``(x / 255 - 0.5) / 0.5``, i.e. to
            [-1, 1] for 8-bit input.

        Raises:
            ValueError: if ``img_raw`` is None or empty (e.g. a failed imread).
        """
        if img_raw is None or np.asarray(img_raw).size == 0:
            raise ValueError("preprocess() received an empty image")
        img = cv2.resize(img_raw, self.size, interpolation=cv2.INTER_LINEAR)
        if img.ndim == 3:
            channels = img.shape[2]
            if channels == 1:
                img = img[:, :, 0]
            elif channels == 4:
                img = cv2.cvtColor(img, cv2.COLOR_RGBA2GRAY)
            else:
                img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        img = np.float32(img) / 255
        img = (img - 0.5) / 0.5
        # Add the channel axis: (H, W) -> (1, H, W).
        return np.expand_dims(img, 0)

    def postprocess(self, infer_results):
        """Greedy (best-path) CTC decode of the raw model output.

        Args:
            infer_results: Score array indexed as [batch, timestep, class] —
                assumed from the max over axis 2; TODO confirm against the
                deployed model's output layout.

        Returns:
            The result of ``self.converter.decode`` on the flattened
            best-class indices (presumably the plate string for batch 1).
        """
        # np.array copies, so torch gets a writable buffer (arrays returned
        # by the Triton client may be read-only, which makes from_numpy warn).
        scores = torch.from_numpy(np.array(infer_results))
        # Best class per (batch, timestep), flattened batch-major. This is
        # the same ordering as transposing to (T, B), taking max and
        # transposing back.
        preds = scores.max(2)[1].contiguous().view(-1)
        preds_size = torch.LongTensor([preds.size(0)])
        return self.converter.decode(preds, preds_size, raw=False)

    def recognition(self, image):
        """Recognise the text in one image using the Triton server.

        Args:
            image: Input image; see ``preprocess``.

        Returns:
            The decoded text from ``postprocess``.

        Raises:
            ValueError: if ``image`` is None or empty.
            RuntimeError: if the server response contains no "output0".
        """
        # (1, H, W) -> (1, 1, H, W): add the batch axis.
        batch = np.expand_dims(self.preprocess(image), 0)
        url = OCR_ip if self.url is None else self.url
        with grpclient.InferenceServerClient(url) as client:
            inputs = [
                grpclient.InferInput(
                    "input0", batch.shape, np_to_triton_dtype(batch.dtype)
                ),
            ]
            inputs[0].set_data_from_numpy(batch)
            outputs = [grpclient.InferRequestedOutput("output0")]
            # request_id is only an opaque tag for the request; the model
            # version has to be selected through model_version.
            response = client.infer(
                self.model_name,
                inputs,
                model_version=self.Version_id,
                outputs=outputs,
            )
        output0_data = response.as_numpy("output0")
        if output0_data is None:
            raise RuntimeError(
                "Triton response for model %r has no 'output0'" % self.model_name
            )
        return self.postprocess(output0_data)
