import numpy as np
import cv2
# triton client
import sys
sys.path.append("..")
#from ..ipconfig import Yolo_ip
from tritonclient.utils import *

import tritonclient.grpc as grpcclient
from tritonclient import utils
import tritonclient.utils.shared_memory as shm

import time

# Address (host:port) handed to the Triton gRPC client.
Yolo_ip = "127.0.0.1:8001"

# Name of the model requested from the server.  "yolov4-tiny-lpd" was the
# previously configured alternative.
Model_name = "yolov4-tiny-3l-lpd"

# Model version, kept as a string.
Version_id = "1"

def _preprocess_yolo(img, input_shape, letter_box=False):
    """Turn an image into the normalized CHW tensor fed to the YOLO model.

    # Args
        img: numpy array of shape (img_h, img_w, 3); converted with
             cv2.COLOR_BGR2RGB, so presumably a BGR uint8 image as read
             by OpenCV
        input_shape: a tuple of (H, W), the network input size
        letter_box: boolean; when True the aspect ratio is kept and the
                    image is centered on a gray (127) canvas of size
                    (H, W), otherwise the image is stretched to (H, W)

    # Returns
        preprocessed img: float32 numpy array of shape (3, H, W) with
                          values divided by 255
    """
    target_h, target_w = input_shape[0], input_shape[1]

    if not letter_box:
        # Plain stretch; cv2.resize takes the size as (width, height).
        canvas = cv2.resize(img, (target_w, target_h))
    else:
        src_h, src_w, _ = img.shape
        fit_h, fit_w = target_h, target_w
        pad_top, pad_left = 0, 0
        if (target_w / src_w) <= (target_h / src_h):
            # Width is the limiting side: shrink height, pad top/bottom.
            fit_h = int(src_h * target_w / src_w)
            pad_top = (target_h - fit_h) // 2
        else:
            # Height is the limiting side: shrink width, pad left/right.
            fit_w = int(src_w * target_h / src_h)
            pad_left = (target_w - fit_w) // 2
        scaled = cv2.resize(img, (fit_w, fit_h))
        canvas = np.full((target_h, target_w, 3), 127, dtype=np.uint8)
        canvas[pad_top:pad_top + fit_h, pad_left:pad_left + fit_w, :] = scaled

    rgb = cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB)
    # HWC -> CHW, then scale into [0, 1] in place.
    chw = rgb.transpose((2, 0, 1)).astype(np.float32)
    chw /= 255.0
    return chw


def _nms_boxes(detections, nms_threshold):
    """Apply the Non-Maximum Suppression (NMS) algorithm on the bounding
    boxes with their confidence scores and return an array with the
    indexes of the bounding boxes we want to keep.

    # Args
        detections: Nx7 numpy arrays of
                    [[x, y, w, h, box_confidence, class_id, class_prob],
                     ......]
        nms_threshold: a candidate is dropped when its IoU with an
                       already kept box is greater than this value

    # Returns
        1-D integer numpy array of row indexes into `detections`, ordered
        by descending score (box_confidence * class_prob).  The array is
        empty, but still integer typed, when `detections` has no rows, so
        it can always be used directly as an index.
    """
    x_coord = detections[:, 0]
    y_coord = detections[:, 1]
    width = detections[:, 2]
    height = detections[:, 3]
    box_confidences = detections[:, 4] * detections[:, 6]

    areas = width * height
    # Right/bottom edges are loop invariant, so compute them once.
    x_end = x_coord + width
    y_end = y_coord + height
    ordered = box_confidences.argsort()[::-1]

    keep = []
    while ordered.size > 0:
        # Best remaining box is always kept; the others are tested
        # against it.
        i = ordered[0]
        rest = ordered[1:]
        keep.append(i)

        xx1 = np.maximum(x_coord[i], x_coord[rest])
        yy1 = np.maximum(y_coord[i], y_coord[rest])
        xx2 = np.minimum(x_end[i], x_end[rest])
        yy2 = np.minimum(y_end[i], y_end[rest])

        # NOTE(review): the "+ 1" treats the overlap as inclusive pixel
        # indexes while `areas` is plain width * height; kept as is to
        # preserve the existing suppression results.
        width1 = np.maximum(0.0, xx2 - xx1 + 1)
        height1 = np.maximum(0.0, yy2 - yy1 + 1)
        intersection = width1 * height1
        union = areas[i] + areas[rest] - intersection
        iou = intersection / union

        # Only boxes that do not overlap too much survive to the next round.
        ordered = rest[iou <= nms_threshold]

    # Explicit integer dtype: np.array([]) would be float64, which numpy
    # refuses to accept as an index array.
    return np.array(keep, dtype=np.intp)


def _postprocess_yolo(trt_outputs, img_w, img_h, conf_th, nms_threshold,
                      input_shape, letter_box=False):
    """Postprocess TensorRT outputs.

    # Args
        trt_outputs: an iterable of tensors, where each tensor
                    contains a multiple of 7 float32 numbers in
                    the order of [x, y, w, h, box_confidence, class_id, class_prob]
        img_w: width of the original image in pixels
        img_h: height of the original image in pixels
        conf_th: confidence threshold applied to
                 box_confidence * class_prob
        nms_threshold: IoU threshold forwarded to _nms_boxes()
        input_shape: a tuple of (H, W), the network input size
        letter_box: boolean, referring to _preprocess_yolo()

    # Returns
        boxes, scores, classes (after NMS)
        boxes is an integer array of shape (N, 4) holding
        [x1, y1, x2, y2]; scores and classes have shape (N,).
    """
    # filter low-conf detections and concatenate results of all yolo layers
    detections = []
    for o in trt_outputs:
        dets = o.reshape((-1, 7))
        dets = dets[dets[:, 4] * dets[:, 6] >= conf_th]
        detections.append(dets)
    detections = np.concatenate(detections, axis=0)

    if len(detections) == 0:
        # `np.int` was removed in NumPy 1.24; the builtin int is what it
        # aliased, so the resulting dtype is unchanged.
        boxes = np.zeros((0, 4), dtype=int)
        scores = np.zeros((0,), dtype=np.float32)
        classes = np.zeros((0,), dtype=np.float32)
        return boxes, scores, classes

    # scale x, y, w, h from [0, 1] to pixel values
    old_h, old_w = img_h, img_w
    offset_h, offset_w = 0, 0
    if letter_box:
        # Recover the size of the padded (letterboxed) frame and the
        # padding that has to be subtracted from the box origins.
        if (img_w / input_shape[1]) >= (img_h / input_shape[0]):
            old_h = int(input_shape[0] * img_w / input_shape[1])
            offset_h = (old_h - img_h) // 2
        else:
            old_w = int(input_shape[1] * img_h / input_shape[0])
            offset_w = (old_w - img_w) // 2
    detections[:, 0:4] *= np.array(
        [old_w, old_h, old_w, old_h], dtype=np.float32)

    # NMS, run independently for every class id (visited in ascending
    # order so the output order is deterministic)
    nms_detections = np.zeros((0, 7), dtype=detections.dtype)
    for class_id in np.unique(detections[:, 5]):
        idxs = np.where(detections[:, 5] == class_id)
        cls_detections = detections[idxs]
        keep = _nms_boxes(cls_detections, nms_threshold)
        nms_detections = np.concatenate(
            [nms_detections, cls_detections[keep]], axis=0)

    xx = nms_detections[:, 0].reshape(-1, 1)
    yy = nms_detections[:, 1].reshape(-1, 1)
    if letter_box:
        xx = xx - offset_w
        yy = yy - offset_h
    ww = nms_detections[:, 2].reshape(-1, 1)
    hh = nms_detections[:, 3].reshape(-1, 1)
    # + 0.5 before the integer cast rounds positive coordinates to nearest
    boxes = np.concatenate([xx, yy, xx+ww, yy+hh], axis=1) + 0.5
    boxes = boxes.astype(int)
    scores = nms_detections[:, 4] * nms_detections[:, 6]
    classes = nms_detections[:, 5]
    return boxes, scores, classes


class TrtYOLOSHM(object):
    """YOLO detector that runs the model on a Triton server over gRPC.

    The input tensor and the output tensor are exchanged with the server
    through two system shared memory regions that are created and
    registered in __init__().  Call close() (or use the instance as a
    context manager) to unregister and destroy those regions when done.
    """

    # Triton region names and the shared memory keys backing them.
    _INPUT_REGION = "input_data"
    _INPUT_KEY = "/input_simple"
    _OUTPUT_REGION = "output_data"
    _OUTPUT_KEY = "/output_simple"

    def __init__(self, category_num=6, letter_box=False, debug=False,
                 input_shape=(320, 320), output_byte_size=44100 * 4):
        """Connect to the server and set up the shared memory regions.

        # Args
            category_num: stored on the instance; not used by the methods
                          of this class
            letter_box: default letterbox setting used by detect()
            debug: forwarded as the second positional argument of
                   grpcclient.InferenceServerClient (presumably its
                   verbose flag -- confirm against the client version)
            input_shape: a tuple of (H, W), the network input size
            output_byte_size: size in bytes of the output region.  It has
                              to match the model and input_shape; values
                              previously used with this file:
                                  17745 * 4  (416x416 yolov4-tiny)
                                  44100 * 4  (320x320 yolov4-tiny-3l)
                                  74529 * 4  (416x416 yolov4-tiny-3l)
        """
        self.category_num = category_num
        self.letter_box = letter_box
        self.input_shape = tuple(input_shape)

        self.client = grpcclient.InferenceServerClient(Yolo_ip, debug)

        # Register shared memory.
        # To make sure no shared memory regions are registered with the
        # server.  NOTE: called without a name, so this clears every
        # region registered on the server, not only the ones of this class.
        self.client.unregister_system_shared_memory()
        self.client.unregister_cuda_shared_memory()

        # 3 channels * H * W values of 4 bytes each (float32 input)
        self.input_byte_size = 3 * self.input_shape[0] * self.input_shape[1] * 4
        self.output_byte_size = output_byte_size

        # Create shared memory region for output and store shared memory handle
        self.shm_op_handle = shm.create_shared_memory_region(
            self._OUTPUT_REGION, self._OUTPUT_KEY, self.output_byte_size)

        # Register shared memory region for outputs with Triton Server
        self.client.register_system_shared_memory(
            self._OUTPUT_REGION, self._OUTPUT_KEY, self.output_byte_size)

        # Create shared memory region for input and store shared memory handle
        self.shm_ip_handle = shm.create_shared_memory_region(
            self._INPUT_REGION, self._INPUT_KEY, self.input_byte_size)

        # Register shared memory region for inputs with Triton Server
        self.client.register_system_shared_memory(
            self._INPUT_REGION, self._INPUT_KEY, self.input_byte_size)

    def close(self):
        """Unregister and destroy the shared memory regions of this instance.

        Safe to call more than once.  After close() the instance can no
        longer run inference.
        """
        for region, attr in ((self._INPUT_REGION, "shm_ip_handle"),
                             (self._OUTPUT_REGION, "shm_op_handle")):
            handle = getattr(self, attr, None)
            if handle is None:
                continue
            # Clear the attribute first so a failure below cannot lead to
            # a second destroy of the same handle.
            setattr(self, attr, None)
            try:
                self.client.unregister_system_shared_memory(region)
            finally:
                # Always release the local region, even if the server
                # could not be reached.
                shm.destroy_shared_memory_region(handle)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def inference_fn(self, img_raw):
        """Run one inference through shared memory.

        # Args
            img_raw: numpy array copied into the input region; assumed to
                     be the float32 (3, H, W) tensor produced by
                     _preprocess_yolo() so that it fits input_byte_size

        # Returns
            numpy array read from the output region with the dtype and
            shape reported by the server for "detections", or None when
            the response has no such output
        """
        # Put input data values into shared memory
        shm.set_shared_memory_region(self.shm_ip_handle, [img_raw])

        # Only shape and dtype are taken from this array; the data itself
        # travels through the shared memory region.
        input0_data = np.expand_dims(img_raw, axis=0)
        inputs = [
            grpcclient.InferInput("input",
                                  input0_data.shape,
                                  np_to_triton_dtype(input0_data.dtype))
        ]
        inputs[0].set_shared_memory(self._INPUT_REGION, self.input_byte_size)

        outputs = [
            grpcclient.InferRequestedOutput("detections")
        ]
        outputs[0].set_shared_memory(self._OUTPUT_REGION, self.output_byte_size)

        results = self.client.infer(model_name=Model_name,
                                    inputs=inputs,
                                    outputs=outputs)
        detections = results.get_output("detections")
        if detections is None:
            return None

        return shm.get_contents_as_numpy(
            self.shm_op_handle, utils.triton_to_np_dtype(detections.datatype),
            detections.shape)

    def detect(self, img, conf_th=0.3, letter_box=None):
        """Detect objects in the input image.

        # Args
            img: numpy array of shape (img_h, img_w, 3)
            conf_th: confidence threshold passed to _postprocess_yolo()
            letter_box: overrides the instance setting when not None

        # Returns
            an iterator of (class_id, box) pairs, where box is
            [x1, y1, x2, y2] clipped to the image; the iterator is empty
            when the server returned no "detections" output
        """
        letter_box = self.letter_box if letter_box is None else letter_box
        img_resized = _preprocess_yolo(img, self.input_shape, letter_box)

        trt_outputs = self.inference_fn(img_resized)
        if trt_outputs is None:
            # inference_fn() signals a missing output with None; report it
            # as "nothing detected" instead of failing in postprocessing.
            return zip((), ())

        boxes, scores, classes = _postprocess_yolo(
            trt_outputs, img.shape[1], img.shape[0], conf_th,
            nms_threshold=0.5, input_shape=self.input_shape,
            letter_box=letter_box)

        # clip x1, y1, x2, y2 within original image
        boxes[:, [0, 2]] = np.clip(boxes[:, [0, 2]], 0, img.shape[1]-1)
        boxes[:, [1, 3]] = np.clip(boxes[:, [1, 3]], 0, img.shape[0]-1)
        return zip(classes, boxes)
        
        
