import numpy as np
import cv2
from itertools import product as product
from math import ceil

# triton client
from tritonclient.utils import *
#import tritonclient.http as httpclient

import tritonclient.grpc as grpcclient
import tritonclient.utils.shared_memory as shm

import time

#from .ipconfig import LPD_ip
# Triton gRPC endpoint ("host:port") that the detector client connects to.
LPD_ip = "127.0.0.1:8001"

# Name of the model requested from the Triton server.
Model_name = "RetinaPlate320x320"

# Model version, stored as a string.
Version_id = "1"

# Gray level subtracted from every pixel during preprocessing; also used as
# the fill value when an image is padded to a square.
gray_mean = 84

def _pad_to_square(image, rgb_mean):
    """Pad ``image`` to a square canvas whose side equals its longer side.

    The original pixels are placed in the top-left corner; the remaining
    area (to the right and/or below) is filled with ``rgb_mean``.

    Args:
        image: array of shape (H, W) or (H, W, C).
        rgb_mean: fill value - a scalar, or a per-channel sequence that
            broadcasts against the trailing (channel) axes of ``image``.
    Returns:
        Array of shape (L, L) or (L, L, C), where L = max(H, W), with the
        same dtype as ``image``.
    """
    height, width = image.shape[:2]
    long_side = max(width, height)

    # Preserve whatever trailing axes the image has, so both 2-D grayscale
    # and 3-D (H, W, C) inputs are supported.
    padded_shape = (long_side, long_side) + tuple(image.shape[2:])
    image_t = np.empty(padded_shape, dtype=image.dtype)
    image_t[...] = rgb_mean
    image_t[:height, :width] = image
    return image_t

def _preprocess_trt(img, shape=(320, 320), is_pading=False):
    """Convert an image into the single-channel network input tensor.

    Steps: grayscale conversion, optional padding to a square (filled with
    ``gray_mean``), resize to ``shape``, mean subtraction, and adding a
    leading channel axis (CHW layout).

    Args:
        img: BGR image of shape (H, W, 3), or an already single-channel
            image of shape (H, W) or (H, W, 1).
        shape: target size as (height, width) - the same convention
            ``generate_priors`` uses for its ``image_size``.
        is_pading: if True, pad the image to a square (bottom/right) before
            resizing so the aspect ratio is kept.
    Returns:
        float32 array of shape (1, height, width).
    """
    if img.ndim == 3 and img.shape[2] == 1:
        # Already single-channel: just drop the channel axis.
        img = img[:, :, 0]
    elif img.ndim == 3:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

    if is_pading:
        img = _pad_to_square(np.expand_dims(img, axis=2), gray_mean)

    # cv2.resize expects dsize as (width, height) while ``shape`` is
    # (height, width); swap them so non-square targets are not transposed
    # relative to the prior boxes. For one channel cv2 returns a 2-D array.
    img = cv2.resize(img, (shape[1], shape[0]), interpolation=cv2.INTER_LINEAR)

    img = img.astype(np.float32)
    img -= gray_mean
    # (H, W) -> (1, H, W)
    return img[np.newaxis, :, :]


# Adapted from https://github.com/Hakuyume/chainer-ssd
def decode(loc, priors, variances):
    """Turn regression offsets into corner-form boxes.

    Inverts the center/size encoding applied to the regression targets:
        center = prior_center + loc[:, :2] * variances[0] * prior_size
        size   = prior_size * exp(loc[:, 2:] * variances[1])

    Args:
        loc: predicted offsets, shape [num_priors, 4]; the first two columns
            are center offsets and the last two are log-scale size offsets.
        priors: prior boxes in (cx, cy, w, h) form, shape [num_priors, 4].
        variances: two-element sequence of encoding variances.
    Returns:
        Boxes of shape [num_priors, 4] as (x1, y1, x2, y2).
    """
    prior_xy = priors[:, :2]
    prior_wh = priors[:, 2:]

    centers = prior_xy + loc[:, :2] * variances[0] * prior_wh
    sizes = prior_wh * np.exp(loc[:, 2:] * variances[1])

    top_left = centers - sizes / 2
    bottom_right = sizes + top_left
    return np.concatenate((top_left, bottom_right), axis=1)

def decode_landm(pre, priors, variances):
    """Decode landmark offsets into (x, y) points relative to the priors.

    Each of the first four (x, y) offset pairs in ``pre`` is mapped to
    prior_center + offset * variances[0] * prior_size. Any columns after
    the eighth are ignored.

    Args:
        pre: landmark offsets, shape [num_priors, >= 8].
        priors: prior boxes in (cx, cy, w, h) form, shape [num_priors, 4].
        variances: sequence whose first item scales the offsets.
    Returns:
        Array of shape [num_priors, 8] laid out as (x1, y1, ..., x4, y4).
    """
    centers = priors[:, :2]
    sizes = priors[:, 2:]
    points = [
        centers + pre[:, col:col + 2] * variances[0] * sizes
        for col in range(0, 8, 2)
    ]
    return np.concatenate(points, 1)

def _postprocess_trt(img, output, conf_th, output_layout=7):
    """Convert a flat SSD-style detection buffer into pixel boxes.

    ``output`` is read in records of ``output_layout`` values. Within a
    record, index 1 is the class id, index 2 the confidence, and indices
    3-6 the box corners (x1, y1, x2, y2) as fractions of the image size.
    Records whose confidence is below ``conf_th`` are skipped.

    Args:
        img: image of shape (H, W, C) used only for its height and width.
        output: flat sequence of detection values.
        conf_th: minimum confidence to keep a record.
        output_layout: number of values per record.
    Returns:
        (boxes, confs, clss): lists of (x1, y1, x2, y2) int tuples in
        pixels, float confidences, and int class ids.
    """
    img_h, img_w, _ = img.shape
    boxes, confs, clss = [], [], []

    for base in range(0, len(output), output_layout):
        score = float(output[base + 2])
        if score < conf_th:
            continue
        corner_x1 = int(output[base + 3] * img_w)
        corner_y1 = int(output[base + 4] * img_h)
        corner_x2 = int(output[base + 5] * img_w)
        corner_y2 = int(output[base + 6] * img_h)
        label = int(output[base + 1])

        boxes.append((corner_x1, corner_y1, corner_x2, corner_y2))
        confs.append(score)
        clss.append(label)

    return boxes, confs, clss


def generate_priors(image_size=(640, 640), clip=False, min_sizes=None, steps=None):
    """Generate normalized prior (anchor) boxes for every feature-map cell.

    For each level, a grid of ceil(H / step) x ceil(W / step) cells covers
    the image. Each cell gets one min_size x min_size (pixel) prior per
    entry of that level's ``min_sizes``, centered on the cell. All values
    are normalized by the image width/height.

    Args:
        image_size: (height, width) of the network input.
        clip: if True, clamp every value into [0, 1].
        min_sizes: per-level lists of prior sizes in pixels; defaults to
            [[16, 32], [64, 128], [256, 512]].
        steps: per-level strides in pixels; defaults to [8, 16, 32].
    Returns:
        float64 array of shape [num_priors, 4] with rows (cx, cy, w, h),
        ordered by level, then row, then column, then size.
    Raises:
        ValueError: if ``min_sizes`` and ``steps`` have different lengths.
    """
    if min_sizes is None:
        min_sizes = [[16, 32], [64, 128], [256, 512]]
    if steps is None:
        steps = [8, 16, 32]
    if len(min_sizes) != len(steps):
        raise ValueError("min_sizes and steps must have the same number of levels")

    img_h, img_w = image_size[0], image_size[1]

    anchors = []
    for step, level_sizes in zip(steps, min_sizes):
        rows, cols = ceil(img_h / step), ceil(img_w / step)
        for i, j in product(range(rows), range(cols)):
            cx = (j + 0.5) * step / img_w
            cy = (i + 0.5) * step / img_h
            for min_size in level_sizes:
                anchors.extend((cx, cy, min_size / img_w, min_size / img_h))

    anchors = np.array(anchors).reshape((-1, 4))
    # Report the number of priors (rows), not the number of coordinates.
    print("priors nums:{}".format(anchors.shape[0]))
    if clip:
        # numpy has no ``clamp``; ``clip`` is the in-place-capable equivalent.
        np.clip(anchors, 0.0, 1.0, out=anchors)
    return anchors

def py_cpu_nms(dets, thresh, offset=1):
    """Greedy non-maximum suppression in pure numpy.

    Repeatedly keeps the highest-scoring remaining box and discards every
    remaining box whose IoU with it exceeds ``thresh``.

    Args:
        dets: array of shape [N, 5] with rows (x1, y1, x2, y2, score).
        thresh: IoU above which a lower-scoring box is suppressed.
        offset: added to every width/height. The default 1 treats
            coordinates as inclusive integer pixel indices (the historical
            behaviour). Pass 0 for continuous or normalized coordinates;
            with +1 on [0, 1] coordinates even disjoint boxes get a large
            "intersection" and are suppressed.
    Returns:
        List of row indices into ``dets`` of the kept boxes, in descending
        score order.
    """
    x1 = dets[:, 0]
    y1 = dets[:, 1]
    x2 = dets[:, 2]
    y2 = dets[:, 3]
    scores = dets[:, 4]

    areas = (x2 - x1 + offset) * (y2 - y1 + offset)
    order = scores.argsort()[::-1]

    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(i)
        rest = order[1:]

        xx1 = np.maximum(x1[i], x1[rest])
        yy1 = np.maximum(y1[i], y1[rest])
        xx2 = np.minimum(x2[i], x2[rest])
        yy2 = np.minimum(y2[i], y2[rest])

        w = np.maximum(0.0, xx2 - xx1 + offset)
        h = np.maximum(0.0, yy2 - yy1 + offset)
        inter = w * h
        ovr = inter / (areas[i] + areas[rest] - inter)

        # Continue only with boxes that do not overlap the kept one too much.
        order = rest[ovr <= thresh]

    return keep


class TrtRetinaPlateSHM(object):
    """RetinaPlate detector served by Triton, exchanging tensors via shared memory.

    The input tensor and the "Class", "Location" and "LandMark" outputs are
    passed through system shared-memory regions that ``__init__`` creates
    locally and registers with the Triton server at ``LPD_ip``. Call
    ``close()`` to unregister and destroy them when the detector is no
    longer needed.

    NOTE(review): region names/keys are fixed strings, so two instances in
    the same process (or on the same host) are likely to collide.
    """

    # (server-side region name, attribute holding the local shm handle)
    _SHM_REGIONS = (
        ("retinaplate_class_data", "shm_class_handle"),
        ("retinaplate_location_data", "shm_location_handle"),
        ("retinaplate_landmark_data", "shm_landmark_handle"),
        ("retinaplate_input_data", "shm_ip_handle"),
    )

    def __init__(self, input_shape, debug=False, do_unregister=False):
        """Connect to Triton and set up the shared-memory regions.

        Args:
            input_shape: network input size as (height, width); also used to
                generate the prior boxes.
            debug: passed as the second positional argument of
                ``grpcclient.InferenceServerClient`` (its verbose flag in
                tritonclient - confirm against the installed version).
            do_unregister: if True, first unregister *all* system and CUDA
                shared-memory regions on the server. This affects every
                client of that server, not just this one.
        """
        self.input_shape = input_shape
        self.priors = generate_priors(input_shape)
        self.variances = [0.1, 0.2]

        self.client = grpcclient.InferenceServerClient(LPD_ip, debug)

        # Register shared memory: first make sure no shared-memory regions
        # are left registered with the server.
        if do_unregister:
            self.client.unregister_system_shared_memory()
            self.client.unregister_cuda_shared_memory()

        # Single-channel input, 4 bytes (float32) per pixel.
        self.input_byte_size = 1 * self.input_shape[0] * self.input_shape[1] * 4
        # num_priors * 4 bytes, i.e. one float32 per prior (assuming float32
        # outputs). The output regions below hold 2 class scores, 4 box
        # offsets and 8 landmark offsets per prior.
        self.output_byte_size = self.priors.shape[0] * self.priors.shape[1]

        self.shm_class_handle = self._create_and_register_region(
            "retinaplate_class_data", "/retinaplate_class_simple",
            self.output_byte_size * 2)
        self.shm_location_handle = self._create_and_register_region(
            "retinaplate_location_data", "/retinaplate_location_simple",
            self.output_byte_size * 4)
        self.shm_landmark_handle = self._create_and_register_region(
            "retinaplate_landmark_data", "/retinaplate_landmark_simple",
            self.output_byte_size * 8)
        self.shm_ip_handle = self._create_and_register_region(
            "retinaplate_input_data", "/retinaplate_input_simple",
            self.input_byte_size)

    def _create_and_register_region(self, name, key, byte_size):
        """Create a local system shared-memory region and register it with Triton.

        Returns:
            The handle returned by ``shm.create_shared_memory_region``.
        """
        handle = shm.create_shared_memory_region(name, key, byte_size)
        self.client.register_system_shared_memory(name, key, byte_size)
        return handle

    def close(self):
        """Unregister this detector's regions from Triton and destroy them locally.

        Safe to call more than once: handles are cleared after release.
        """
        for name, attr in self._SHM_REGIONS:
            handle = getattr(self, attr, None)
            if handle is None:
                continue
            # Unregister on the server before destroying the local region.
            self.client.unregister_system_shared_memory(name)
            shm.destroy_shared_memory_region(handle)
            setattr(self, attr, None)

    @staticmethod
    def _read_output(results, output_name, handle):
        """Read output ``output_name`` from its shared-memory region, or None if absent."""
        output = results.get_output(output_name)
        if output is None:
            return None
        return shm.get_contents_as_numpy(
            handle, triton_to_np_dtype(output.datatype), output.shape)

    def infer_triton_trt(self, img_raw):
        """Run one inference request through the shared-memory regions.

        Args:
            img_raw: preprocessed input array (e.g. (1, H, W) float32); a
                batch axis is added for the request shape.
        Returns:
            (class_data, location_data, landmark_data) as numpy arrays read
            from the output regions, or None for any output the server did
            not report. NOTE(review): these may alias shared memory and be
            overwritten by the next call - copy them if they must persist.
        Raises:
            ValueError: if ``img_raw`` does not fit the input region.
        """
        # The shm write copies raw memory, so the array must be contiguous.
        img_raw = np.ascontiguousarray(img_raw)
        if img_raw.nbytes > self.input_byte_size:
            raise ValueError(
                "input of {} bytes does not fit the {}-byte shared-memory "
                "region".format(img_raw.nbytes, self.input_byte_size))

        # Put input data values into shared memory.
        shm.set_shared_memory_region(self.shm_ip_handle, [img_raw])

        input0_data = np.expand_dims(img_raw, axis=0)
        inputs = [
            grpcclient.InferInput("Input",
                                  input0_data.shape,
                                  np_to_triton_dtype(input0_data.dtype))
        ]
        # Tell the server how many bytes of the region hold this input.
        inputs[0].set_shared_memory("retinaplate_input_data", img_raw.nbytes)

        outputs = [
            grpcclient.InferRequestedOutput("Class"),
            grpcclient.InferRequestedOutput("Location"),
            grpcclient.InferRequestedOutput("LandMark"),
        ]
        outputs[0].set_shared_memory("retinaplate_class_data", self.output_byte_size * 2)
        outputs[1].set_shared_memory("retinaplate_location_data", self.output_byte_size * 4)
        outputs[2].set_shared_memory("retinaplate_landmark_data", self.output_byte_size * 8)

        results = self.client.infer(model_name=Model_name,
                                    inputs=inputs,
                                    outputs=outputs)

        class_data = self._read_output(results, "Class", self.shm_class_handle)
        location_data = self._read_output(results, "Location", self.shm_location_handle)
        landmark_data = self._read_output(results, "LandMark", self.shm_landmark_handle)
        return class_data, location_data, landmark_data

    def detect(self, img, conf_th=0.3, top_k=30, nms_threshold=0.01, keep_top_k=200):
        """Detect objects in a BGR image.

        Args:
            img: input image; its height and width are used to scale results.
            conf_th: minimum class-1 score for a prior to be considered.
            top_k: number of highest-scoring candidates kept before NMS.
            nms_threshold: IoU threshold passed to ``py_cpu_nms``.
            keep_top_k: maximum number of detections returned after NMS.
        Returns:
            Array of shape [K, 13]: (x1, y1, x2, y2, score) followed by 8
            landmark coordinates, all in pixel units of ``img``.
        Raises:
            RuntimeError: if the server did not return all three outputs.
        """
        im_height, im_width = img.shape[:2]

        img_resized = _preprocess_trt(img, self.input_shape)
        cla, bbox, land = self.infer_triton_trt(img_resized)
        if cla is None or bbox is None or land is None:
            raise RuntimeError(
                "Triton did not return all of the Class/Location/LandMark outputs")

        loc = decode(bbox.reshape((-1, 4)), self.priors, self.variances)
        land = decode_landm(land.reshape((-1, 8)), self.priors, self.variances)
        scores = cla.reshape((-1, 2))[:, 1]

        # Ignore low scores.
        inds = np.where(scores > conf_th)[0]
        scores = scores[inds]
        loc = loc[inds]
        land = land[inds]

        # Keep the top_k highest-scoring candidates before NMS.
        order = scores.argsort()[::-1][:top_k]
        scores = scores[order]
        loc = loc[order]
        land = land[order]

        # Convert from normalized to pixel coordinates *before* NMS:
        # py_cpu_nms adds 1 to every width/height (pixel convention), which on
        # [0, 1] coordinates makes even disjoint boxes look overlapping and
        # suppresses nearly every detection but the best one.
        scale = np.array([im_width, im_height, im_width, im_height])
        scale1 = np.array([im_width, im_height, im_width, im_height,
                           im_width, im_height, im_width, im_height])
        loc = loc * scale
        land = land * scale1

        dets = np.hstack((loc, scores[:, np.newaxis])).astype(np.float32, copy=False)
        keep = py_cpu_nms(dets, nms_threshold)

        # Keep at most keep_top_k detections after NMS.
        dets = dets[keep, :][:keep_top_k, :]
        land = land[keep][:keep_top_k, :]

        return np.concatenate((dets, land), axis=1)