import sys
import os
sys.path.append(os.getcwd())
import argparse
import torch
from model.MTCNN_nets import PNet, ONet
import math
import numpy as np
from utils.util import *
import cv2
import time

class MTCNN_main:
    """Two-stage (P-Net -> O-Net) MTCNN-style licence-plate detector.

    P-Net proposes candidate boxes over an image pyramid.  O-Net re-scores
    those candidates on grayscale crops, refines them and predicts landmark
    points for each kept box.
    """

    def __init__(self, dev=None, p_model_path=None, o_model_path=None):
        """Create the detector and load every network that has a weight path.

        Args:
            dev: torch device (or device string) the networks run on.
            p_model_path: path to a P-Net ``state_dict``; ``self.pnet`` stays
                ``None`` when omitted.
            o_model_path: path to an O-Net ``state_dict``; ``self.onet`` stays
                ``None`` when omitted.
        """
        self.pnet = None
        self.onet = None
        self.device = None
        self.create_mtcnn_net(dev, p_model_path, o_model_path)

    def create_mtcnn_net(self, dev, p_model_path=None, o_model_path=None):
        """Build P-Net / O-Net on ``dev``, load their weights and set eval mode.

        Checkpoints are deserialized onto the CPU first; ``load_state_dict``
        then copies the tensors into the modules already moved to ``dev``.
        NOTE(review): ``torch.load`` unpickles the file, so only load trusted
        checkpoints.
        """
        self.device = dev
        if p_model_path is not None:
            self.pnet = PNet().to(self.device)
            self.pnet.load_state_dict(torch.load(p_model_path, map_location="cpu"))
            self.pnet.eval()
            print("Pnet loaded!!")
        if o_model_path is not None:
            self.onet = ONet().to(self.device)
            self.onet.load_state_dict(torch.load(o_model_path, map_location="cpu"))
            self.onet.eval()
            print("Onet loaded!!")

    def detect_plate(self, image, mini_lp_size, is_train=False):
        """Run the detection cascade on a single image.

        Args:
            image: (H, W, C) image array.  O-Net converts it with
                ``cv2.COLOR_BGR2GRAY``, so a 3-channel BGR image is expected.
            mini_lp_size: ``(min_width, min_height)`` of the smallest plate to
                search for.  It bounds the depth of the image pyramid.
            is_train: when True, only the P-Net stage runs.

        Returns:
            ``bboxes`` alone when ``is_train`` is True, otherwise the tuple
            ``(bboxes, landmarks)``.  Box rows are ``[x1, y1, x2, y2, score]``.
            A missing network is reported on stdout, and its stage's output
            stays an empty array.
        """
        bboxes = np.array([])
        landmarks = np.array([])
        if self.pnet is not None:
            bboxes = self.detect_pnet(self.pnet, image, mini_lp_size, self.device)
        else:
            print("Pnet not loaded")

        if is_train:
            return bboxes

        if self.onet is None:
            print("Onet not loaded")
        elif bboxes.size > 0:
            # Without proposals (e.g. P-Net not loaded) there is nothing to
            # crop, and detect_onet cannot handle an empty 1-D box array.
            bboxes, landmarks = self.detect_onet(self.onet, image, bboxes, self.device)
        return bboxes, landmarks

    def detect_pnet(self, pnet, image, min_lp_size, device,
                    prob_threshold=0.6, nms_threshold=0.7):
        """Propose plate boxes by running P-Net over an image pyramid.

        Args:
            pnet: P-Net module.  Its forward returns
                ``(landmarks, offsets, probs)``; landmarks are ignored here.
            image: (H, W, C) image array.
            min_lp_size: ``(min_width, min_height)``.  Pyramid levels are added
                while the shrunken image is still larger than this in both
                dimensions.
            device: torch device the network input is moved to.
            prob_threshold: minimum plate probability for an output cell.
            nms_threshold: overlap threshold of the cross-scale NMS.

        Returns:
            (N, 5) array of ``[x1, y1, x2, y2, score]``, refined with P-Net's
            offsets and with coordinates rounded.
            NOTE(review): when no cell passes the threshold, a single all-zero
            placeholder row is returned (behaviour kept for callers).
        """
        height, width, _ = image.shape

        # ~sqrt(0.5): each pyramid level has roughly half the previous area.
        factor = 0.707
        scales = []
        level_h, level_w = height, width
        level = 0
        while level_h > min_lp_size[1] and level_w > min_lp_size[0]:
            scales.append(factor ** level)
            level_h *= factor
            level_w *= factor
            level += 1

        # Box geometry assumed for each P-Net output cell, in (y, x) order:
        # a 12-high x 44-wide window moved by 2 px vertically and 5 px
        # horizontally.
        stride, cell_size = (2, 5), (12, 44)

        candidates = []
        with torch.no_grad():
            for scale in scales:
                sw, sh = math.ceil(width * scale), math.ceil(height * scale)
                img = cv2.resize(image, (sw, sh), interpolation=cv2.INTER_LINEAR)
                img = torch.FloatTensor(preprocess(img)).to(device)
                _, offset, prob = pnet(img)
                probs = prob.cpu().numpy()[0, 1, :, :]  # plate prob per cell
                offsets = offset.cpu().numpy()          # box regression per cell

                rows, cols = np.where(probs > prob_threshold)
                if rows.size == 0:
                    continue

                reg = np.array([offsets[0, i, rows, cols] for i in range(4)])
                score = probs[rows, cols]
                # Map cell indices back to coordinates of the unscaled image.
                boxes = np.vstack([
                    np.round((stride[1] * cols + 1.0) / scale),
                    np.round((stride[0] * rows + 1.0) / scale),
                    np.round((stride[1] * cols + 1.0 + cell_size[1]) / scale),
                    np.round((stride[0] * rows + 1.0 + cell_size[0]) / scale),
                    score, reg]).T
                # Per-scale NMS; the result must be kept, not discarded.
                keep = nms(boxes[:, 0:5], overlap_threshold=0.5)
                candidates.append(boxes[keep])

        if candidates:
            bounding_boxes = np.vstack(candidates)
            keep = nms(bounding_boxes[:, 0:5], nms_threshold)
            bounding_boxes = bounding_boxes[keep]
        else:
            bounding_boxes = np.zeros((1, 9))

        # Apply P-Net's predicted offsets (columns 5:9) to the boxes.
        bboxes = calibrate_box(bounding_boxes[:, 0:5], bounding_boxes[:, 5:])
        bboxes[:, 0:4] = np.round(bboxes[:, 0:4])
        return bboxes

    def detect_onet(self, onet, image, bboxes, device,
                    prob_threshold=0.8, nms_threshold=0.7):
        """Re-score, refine and annotate candidate boxes with O-Net.

        Args:
            onet: O-Net module.  Its forward returns
                ``(landmarks, offsets, probs)``.
            image: (H, W, 3) BGR image; it is converted to grayscale here.
            bboxes: (N, 5) candidates ``[x1, y1, x2, y2, score]``.
            device: torch device for the crop batch.
            prob_threshold: minimum O-Net plate probability needed to keep a box.
            nms_threshold: overlap threshold of the final NMS (``mode='min'``).

        Returns:
            ``(bboxes, landmarks)``.  Kept boxes carry O-Net's score in column
            4 and have refined, rounded coordinates.  Landmark columns 0-7 are
            mapped from box-relative fractions to image pixels as
            interleaved x, y values.
        """
        size = (94, 24)  # O-Net input size (width, height) for cv2.resize
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        gray = np.expand_dims(gray, 2)
        height, width, _ = gray.shape
        num_boxes = len(bboxes)
        dy, edy, dx, edx, y, ey, x, ex, w, h = correct_bboxes(bboxes, width, height)

        img_boxes = np.zeros((num_boxes, 1, size[1], size[0]))
        for i in range(num_boxes):
            # Crop into a zeroed canvas.  Only the [dy:edy, dx:edx] window from
            # correct_bboxes is copied from the image; the rest stays 0.
            img_box = np.zeros((h[i], w[i], 1))
            img_box[dy[i]:(edy[i] + 1), dx[i]:(edx[i] + 1), :] = \
                gray[y[i]:(ey[i] + 1), x[i]:(ex[i] + 1), :]
            img_box = cv2.resize(img_box, size, interpolation=cv2.INTER_LINEAR)
            img_boxes[i, :, :, :] = preprocess(img_box, True)

        # Inference only: skip autograd bookkeeping.  This also lets
        # .numpy() be called on the outputs.
        with torch.no_grad():
            landmarks, offset, prob = onet(torch.FloatTensor(img_boxes).to(device))
        offsets = offset.cpu().numpy()
        probs = prob.cpu().numpy()
        landmarks = landmarks.cpu().numpy()

        keep = np.where(probs[:, 1] > prob_threshold)[0]
        bboxes = bboxes[keep]
        bboxes[:, 4] = probs[keep, 1].reshape((-1,))  # replace P-Net score
        offsets = offsets[keep]
        landmarks = landmarks[keep]

        # Landmarks are predicted as fractions of the (pre-calibration) box.
        # Even columns are x, odd columns are y.
        x1 = bboxes[:, 0:1]
        y1 = bboxes[:, 1:2]
        box_w = bboxes[:, 2:3] - x1 + 1.0
        box_h = bboxes[:, 3:4] - y1 + 1.0
        landmarks[:, 0:8:2] = x1 + box_w * landmarks[:, 0:8:2]
        landmarks[:, 1:8:2] = y1 + box_h * landmarks[:, 1:8:2]

        bboxes = calibrate_box(bboxes, offsets)
        keep = nms(bboxes, nms_threshold, mode='min')
        bboxes = bboxes[keep]
        landmarks = landmarks[keep]
        bboxes[:, 0:4] = np.round(bboxes[:, 0:4])
        return bboxes, landmarks