import math

import MNN
import cv2
import numpy as np
import torch

class FaceDetector:
    """Face detector running an MNN model with SSD-style post-processing.

    The detector resizes an image to ``input_size``, normalizes it, runs the
    MNN session, decodes the regressed locations against pre-computed prior
    boxes and finally applies hard non-maximum suppression per class.
    """

    def __init__(self, model_path="../model/version-slim/slim-320.mnn", input_size=(320, 240)):
        """Load the MNN model and pre-compute the prior boxes.

        Args:
            model_path: path of the ``.mnn`` model file handed to
                ``MNN.Interpreter``.
            input_size: network input size as ``(width, height)``.
        """
        self.input_size = input_size
        # Normalization applied in ``predict``: (pixel - mean) / std.
        self.image_mean = np.array([127, 127, 127])
        self.image_std = 128.0
        # IoU above which a lower-scored box is suppressed during NMS.
        self.iou_threshold = 0.3
        # Minimum class probability for a box to be kept.
        self.threshold = 0.7
        # Scale factors used when decoding regressed centers and sizes.
        self.center_variance = 0.1
        self.size_variance = 0.2
        # Prior box side lengths (in input pixels), one list per feature map.
        self.min_boxes = [[10, 16, 24], [32, 48], [64, 96], [128, 192, 256]]
        # Down-sampling factor of each feature map relative to the input.
        self.strides = [8, 16, 32, 64]
        self.priors = self.define_img_size(self.input_size)
        self.interpreter = MNN.Interpreter(model_path)
        self.session = self.interpreter.createSession({'numThread': 4})
        self.input_tensor = self.interpreter.getSessionInput(self.session)

    def predict(self, image_ori):
        """Detect faces in an image.

        Args:
            image_ori: image array of shape (height, width, 3). No color
                conversion is done here.
                NOTE(review): confirm the channel order the model expects.

        Returns:
            The ``(boxes, labels, probs)`` tuple produced by ``decode``;
            boxes are in pixel coordinates of ``image_ori``.
        """
        resized = cv2.resize(image_ori, self.input_size)
        normalized = (resized - self.image_mean) / self.image_std
        # HWC -> CHW, as the input tensor below is declared as (1, 3, H, W).
        chw = normalized.transpose((2, 0, 1)).astype(np.float32)
        tmp_input = MNN.Tensor(
            (1, 3, self.input_size[1], self.input_size[0]),
            MNN.Halide_Type_Float,
            chw,
            MNN.Tensor_DimensionType_Caffe,
        )
        self.input_tensor.copyFrom(tmp_input)
        self.interpreter.runSession(self.session)

        raw_scores = self.interpreter.getSessionOutput(self.session, "scores").getData()
        raw_boxes = self.interpreter.getSessionOutput(self.session, "boxes").getData()
        # Add a leading batch axis: (1, num_priors, 4) and (1, num_priors, 2).
        boxes = np.expand_dims(np.reshape(raw_boxes, (-1, 4)), axis=0)
        scores = np.expand_dims(np.reshape(raw_scores, (-1, 2)), axis=0)

        boxes = self.convert_locations_to_boxes(
            boxes, self.priors, self.center_variance, self.size_variance)
        boxes = self.center_form_to_corner_form(boxes)
        # Forward the configured IoU threshold instead of silently relying on
        # the default of ``decode``.
        return self.decode(
            image_ori.shape[1],
            image_ori.shape[0],
            scores,
            boxes,
            self.threshold,
            iou_threshold=self.iou_threshold,
        )

    def define_img_size(self, image_size):
        """Build the prior boxes for a given network input size.

        Args:
            image_size: ``(width, height)`` of the network input.

        Returns:
            The priors returned by ``generate_priors``.
        """
        # One list of feature-map extents per image dimension (width, height).
        feature_map_w_h_list = [
            [math.ceil(size / stride) for stride in self.strides]
            for size in image_size
        ]
        # The same strides apply to both dimensions.
        shrinkage_list = [self.strides for _ in image_size]
        return self.generate_priors(
            feature_map_w_h_list, shrinkage_list, image_size, self.min_boxes)

    def generate_priors(self, feature_map_list, shrinkage_list, image_size, min_boxes, clamp=True):
        """Generate prior boxes in center form, relative to the image size.

        Args:
            feature_map_list: ``[widths, heights]`` of the feature maps.
            shrinkage_list: ``[strides_w, strides_h]`` of the feature maps.
            image_size: ``(width, height)`` of the network input.
            min_boxes: box side lengths (pixels), one list per feature map.
            clamp: when true, clip every prior value into ``[0, 1]``.

        Returns:
            ``torch.Tensor`` of rows ``[center_x, center_y, w, h]``, ordered
            by feature map, then row, then column, then box size.
        """
        priors = []
        for index, num_cols in enumerate(feature_map_list[0]):
            num_rows = feature_map_list[1][index]
            scale_w = image_size[0] / shrinkage_list[0][index]
            scale_h = image_size[1] / shrinkage_list[1][index]
            for j in range(num_rows):
                y_center = (j + 0.5) / scale_h
                for i in range(num_cols):
                    x_center = (i + 0.5) / scale_w
                    for min_box in min_boxes[index]:
                        priors.append([
                            x_center,
                            y_center,
                            min_box / image_size[0],
                            min_box / image_size[1],
                        ])
        print("priors nums:{}".format(len(priors)))
        priors = torch.tensor(priors)
        if clamp:
            torch.clamp(priors, 0.0, 1.0, out=priors)
        return priors

    def decode(self, width, height, confidences, boxes, prob_threshold, iou_threshold=0.3, top_k=-1):
        """Filter, suppress and rescale the network predictions.

        Args:
            width: width used to scale x coordinates to pixels.
            height: height used to scale y coordinates to pixels.
            confidences (1, num_priors, num_classes): class probabilities;
                column 0 is skipped.
            boxes (1, num_priors, 4): corner-form boxes relative to the image.
            prob_threshold: only boxes scoring above this value are kept.
            iou_threshold: IoU threshold forwarded to ``hard_nms``.
            top_k: forwarded to ``hard_nms``; ``<= 0`` keeps all results.

        Returns:
            ``(boxes, labels, probs)``: int32 pixel boxes of shape (K, 4),
            class indexes of shape (K,) and probabilities of shape (K,).
            Three empty arrays are returned when nothing passes the threshold.
        """
        boxes = boxes[0]
        confidences = confidences[0]
        picked_box_probs = []
        picked_labels = []
        for class_index in range(1, confidences.shape[1]):
            class_probs = confidences[:, class_index]
            keep = class_probs > prob_threshold
            class_probs = class_probs[keep]
            if class_probs.shape[0] == 0:
                continue
            candidates = np.concatenate(
                [boxes[keep, :], class_probs.reshape(-1, 1)], axis=1)
            kept = self.hard_nms(candidates, iou_threshold=iou_threshold, top_k=top_k)
            picked_box_probs.append(kept)
            picked_labels.extend([class_index] * kept.shape[0])
        if not picked_box_probs:
            return np.array([]), np.array([]), np.array([])
        picked_box_probs = np.concatenate(picked_box_probs)
        # Scale relative coordinates (x1, y1, x2, y2) to pixels.
        picked_box_probs[:, 0] *= width
        picked_box_probs[:, 1] *= height
        picked_box_probs[:, 2] *= width
        picked_box_probs[:, 3] *= height
        return (
            picked_box_probs[:, :4].astype(np.int32),
            np.array(picked_labels),
            picked_box_probs[:, 4],
        )

    def center_form_to_corner_form(self, locations):
        """Convert ``[cx, cy, w, h]`` boxes to ``[x1, y1, x2, y2]`` boxes.

        Args:
            locations (..., 4): boxes in center form.

        Returns:
            Array of the same shape holding the boxes in corner form.
        """
        centers = locations[..., :2]
        half_sizes = locations[..., 2:] / 2
        return np.concatenate(
            [centers - half_sizes, centers + half_sizes],
            len(locations.shape) - 1,
        )

    def convert_locations_to_boxes(self, locations, priors, center_variance,
                                   size_variance):
        """Convert regressed SSD locations into center-form boxes.

        The network output is defined by:
            predicted_center * center_variance
                = (real_center - prior_center) / prior_size
            exp(predicted_size * size_variance) = real_size / prior_size
        This method applies the inverse of that mapping.

        Args:
            locations (batch_size, num_priors, 4): the regression output.
            priors (num_priors, 4) or (batch_size/1, num_priors, 4): prior
                boxes as a NumPy array or anything ``np.asarray`` accepts
                (e.g. the tensor returned by ``generate_priors``).
            center_variance: a float used to change the scale of center.
            size_variance: a float used to change the scale of size.

        Returns:
            NumPy array ``[[center_x, center_y, w, h]]`` with the same size
            order as ``priors``. All values are relative to the image size.
        """
        # Always work on a NumPy array so torch priors behave the same
        # whether or not they already carry a batch axis.
        priors = np.asarray(priors)
        # priors can have one dimension less.
        if len(priors.shape) + 1 == len(locations.shape):
            priors = np.expand_dims(priors, 0)
        prior_centers = priors[..., :2]
        prior_sizes = priors[..., 2:]
        return np.concatenate([
            locations[..., :2] * center_variance * prior_sizes + prior_centers,
            np.exp(locations[..., 2:] * size_variance) * prior_sizes,
        ], axis=len(locations.shape) - 1)

    def area_of(self, left_top, right_bottom):
        """Compute the areas of rectangles given two corners.

        Args:
            left_top (N, 2): left top corner.
            right_bottom (N, 2): right bottom corner.

        Returns:
            area (N): the areas; inverted rectangles yield 0.
        """
        extent = np.clip(right_bottom - left_top, 0.0, None)
        return extent[..., 0] * extent[..., 1]

    def iou_of(self, boxes0, boxes1, eps=1e-5):
        """Return intersection-over-union (Jaccard index) of boxes.

        Args:
            boxes0 (N, 4): corner-form boxes.
            boxes1 (N or 1, 4): corner-form boxes, broadcast against boxes0.
            eps: a small number to avoid 0 as denominator.

        Returns:
            iou (N): IoU values.
        """
        overlap_left_top = np.maximum(boxes0[..., :2], boxes1[..., :2])
        overlap_right_bottom = np.minimum(boxes0[..., 2:], boxes1[..., 2:])

        overlap_area = self.area_of(overlap_left_top, overlap_right_bottom)
        area0 = self.area_of(boxes0[..., :2], boxes0[..., 2:])
        area1 = self.area_of(boxes1[..., :2], boxes1[..., 2:])
        return overlap_area / (area0 + area1 - overlap_area + eps)

    def hard_nms(self, box_scores, iou_threshold, top_k=-1, candidate_size=200):
        """Apply hard non-maximum suppression.

        Args:
            box_scores (N, 5): boxes in corner-form and probabilities.
            iou_threshold: intersection over union threshold; boxes whose
                IoU with a picked box exceeds it are dropped.
            top_k: keep top_k results. If k <= 0, keep all the results.
            candidate_size: only consider the candidates with the highest
                scores.

        Returns:
            The kept rows of ``box_scores`` (K, 5), highest score first.
        """
        scores = box_scores[:, -1]
        boxes = box_scores[:, :-1]
        picked = []
        # argsort is ascending, so the best candidate is always the last one.
        indexes = np.argsort(scores)[-candidate_size:]
        while len(indexes) > 0:
            current = indexes[-1]
            picked.append(current)
            if 0 < top_k == len(picked) or len(indexes) == 1:
                break
            indexes = indexes[:-1]
            iou = self.iou_of(
                boxes[indexes, :],
                np.expand_dims(boxes[current, :], axis=0),
            )
            # Keep only candidates that do not overlap the picked box too much.
            indexes = indexes[iou <= iou_threshold]

        return box_scores[picked, :]