import sys
sys.path.append("Pytorch_Retinaface")
import torch
import torch.backends.cudnn as cudnn
import numpy as np
from .data import cfg_mnet, cfg_re50
from layers.functions.prior_box import PriorBox
from utils.nms.py_cpu_nms import py_cpu_nms
import cv2
from .models.retinaface import RetinaFace
from utils.box_utils import decode, decode_landm
from Pytorch_Retinaface import args
import time


def check_keys(model, pretrained_state_dict):
    """Print how the checkpoint's parameter names overlap with the model's.

    Three counts are printed: names the model has but the checkpoint lacks,
    names only the checkpoint has, and names present in both.

    Args:
        model: object exposing ``state_dict()``; only its keys are inspected.
        pretrained_state_dict: mapping loaded from the checkpoint.

    Returns:
        True when at least one name is shared.

    Raises:
        AssertionError: if the checkpoint shares no name with the model.
    """
    checkpoint_names = set(pretrained_state_dict.keys())
    model_names = set(model.state_dict().keys())
    shared_names = model_names.intersection(checkpoint_names)

    # (message template, set whose size is reported), in print order.
    report = (
        ('Missing keys:{}', model_names.difference(checkpoint_names)),
        ('Unused checkpoint keys:{}', checkpoint_names.difference(model_names)),
        ('Used keys:{}', shared_names),
    )
    for template, names in report:
        print(template.format(len(names)))

    assert len(shared_names) > 0, 'load NONE from pretrained checkpoint'
    return True


def remove_prefix(state_dict, prefix):
    ''' Return a copy of state_dict whose keys have one leading `prefix` removed.

    Old style checkpoints store every parameter name under a common prefix
    such as 'module.'. Keys that do not start with the prefix are kept as is,
    and only the first occurrence of the prefix is stripped.
    '''
    print('remove prefix \'{}\''.format(prefix))

    def strip(name):
        if not name.startswith(prefix):
            return name
        # Split once so a repeated prefix further inside the name survives.
        return name.split(prefix, 1)[-1]

    renamed = {}
    for name, value in state_dict.items():
        renamed[strip(name)] = value
    return renamed

def load_model(model, pretrained_path, load_to_cpu):
    """Load checkpoint weights from `pretrained_path` into `model`.

    The checkpoint may be either a bare state dict or a dict that holds the
    weights under a "state_dict" entry. A leading 'module.' is stripped from
    every parameter name before loading, and loading is non-strict, so keys
    that do not match are skipped rather than raising.

    Args:
        model: module receiving the weights.
        pretrained_path: path handed to ``torch.load``.
        load_to_cpu: when true, storages stay where ``torch.load`` put them;
            otherwise they are moved to the current CUDA device.

    Returns:
        The same `model` object, with the weights loaded.
    """
    print('Loading pretrained model from {}'.format(pretrained_path))

    if load_to_cpu:
        def place(storage, loc):
            return storage
    else:
        device = torch.cuda.current_device()

        def place(storage, loc):
            return storage.cuda(device)

    pretrained_dict = torch.load(pretrained_path, map_location=place)

    # Unwrap checkpoints that nest the weights under "state_dict".
    if "state_dict" in pretrained_dict.keys():
        weights = pretrained_dict['state_dict']
    else:
        weights = pretrained_dict
    pretrained_dict = remove_prefix(weights, 'module.')

    check_keys(model, pretrained_dict)
    model.load_state_dict(pretrained_dict, strict=False)
    return model



class RetinaPlate():
    """Detector wrapper that runs a RetinaFace-style network on one image.

    The backbone configuration is picked from ``args.network`` and the
    compute device from ``args.cpu``. Inference runs through one of three
    backends chosen by `mode`: a PyTorch module ('torch'), an onnxruntime
    session ('onnx') or a TensorRT engine ('trt').

    Prior boxes depend on the input resolution, so they are computed once per
    (height, width) and kept in ``self.priorboxes`` under a "H:W" string key.
    The cache is never pruned; it grows by one entry per distinct image size.
    """

    def __init__(self, im_height, im_width, mode='torch',
                 onnx_path='./alpr/Models/test.onnx'):
        """Build the backend and precompute priors for the expected size.

        Args:
            im_height: expected input image height in pixels.
            im_width: expected input image width in pixels.
            mode: 'torch', 'onnx' or 'trt'.
            onnx_path: model file opened when ``mode == 'onnx'``.

        Raises:
            ValueError: if ``args.network`` names no known configuration.
            AssertionError: if `mode` is not one of the supported values.
        """
        # NOTE: this disables autograd for the whole process, not just this
        # object.
        torch.set_grad_enabled(False)
        self.mode = mode

        self.cfg = None
        if args.network == "mobile0.25":
            self.cfg = cfg_mnet
        elif args.network == "resnet50":
            self.cfg = cfg_re50
        else:
            # Without a config, detection would fail later when it reads
            # self.cfg['variance']; fail here with a readable message instead.
            raise ValueError(
                "Unsupported args.network: {!r}".format(args.network))

        self.device = torch.device("cpu" if args.cpu else "cuda")

        # net and model
        if mode == 'torch':
            self.net = RetinaFace(cfg=self.cfg, phase='test')
            self.net = load_model(self.net, args.trained_model, args.cpu)
            self.net.eval()
            self.net = self.net.to(self.device)
        elif mode == 'onnx':
            import onnxruntime as ort
            self.net = ort.InferenceSession(onnx_path)
        elif mode == 'trt':
            print("TRT enginge inference")
            # Both are created lazily on the first detect_plate() call.
            self.cuda_ctx = None
            self.engine = None
        else:
            # Raised explicitly (not via `assert`) so the check survives
            # `python -O`; the exception type and message are unchanged.
            raise AssertionError("Error: wrong setting mode")

        print('Finished loading model!')
        cudnn.benchmark = True

        self.im_height = im_height
        self.im_width = im_width
        self.priorboxes = dict()
        self.prior_data = None
        self.init_priorbox()

    def init_priorbox(self):
        """Point ``self.prior_data`` at the priors for the current image size.

        Priors are taken from ``self.priorboxes`` when that size has been
        seen before; otherwise they are generated, moved to ``self.device``
        and cached.
        """
        key = "%d:%d" % (self.im_height, self.im_width)
        priors = self.priorboxes.get(key)

        # Identity check: `== None` on a tensor goes through tensor
        # comparison instead of asking "is this entry missing?".
        if priors is None:
            priorbox = PriorBox(
                self.cfg, image_size=(self.im_height, self.im_width))
            priors = priorbox.forward().to(self.device).data
            self.priorboxes[key] = priors

        self.prior_data = priors

    def _forward(self, img):
        """Run the configured backend on one channel-first float32 image.

        Returns:
            (loc, conf, landms) as torch tensors on ``self.device``.
        """
        if self.mode == 'torch':
            batch = torch.from_numpy(img).unsqueeze(0).to(self.device)
            with torch.no_grad():
                loc, conf, landms = self.net(batch)  # forward pass
            return loc, conf, landms

        batch = np.expand_dims(img, axis=0)
        if self.mode == 'onnx':
            loc, conf, landms = self.net.run(None, {'input0': batch})
        elif self.mode == 'trt':
            from retina_trt import trt_engine
            import pycuda.driver as cuda
            if self.cuda_ctx is None:
                self.cuda_ctx = cuda.Device(0).make_context()  # GPU 0
            if self.engine is None:
                self.engine = trt_engine()
            # Load data to the buffer
            self.engine.load_to_buffers(batch)
            loc, conf, landms = self.engine.go_inference()
        else:
            raise AssertionError("Error: wrong setting mode")

        # These backends return numpy arrays; convert them to tensors on the
        # target device so decoding is shared with the torch path.
        return tuple(
            torch.from_numpy(out).to(self.device)
            for out in (loc, conf, landms)
        )

    def detect_plate(self, img_raw):
        """Detect objects in one image and return boxes, scores, landmarks.

        Args:
            img_raw: image array of shape (height, width, 3). It is copied
                before preprocessing, so the caller's array is not modified.

        Returns:
            2-D float array with one row per kept detection: the decoded box
            (4 values scaled by width/height), the score, then the decoded
            landmarks (scaled by an 8-element width/height vector). Rows are
            thresholded by ``args.confidence_threshold``, limited to
            ``args.top_k`` before NMS and ``args.keep_top_k`` after it.
        """
        resize = 1  # the input is never rescaled, so outputs need no undo

        # np.array always copies. np.float32(img_raw) may hand back the very
        # same array when the input is already float32, and the in-place
        # subtraction below would then alter the caller's image.
        img = np.array(img_raw, dtype=np.float32)
        im_height, im_width, _ = img.shape
        if im_height != self.im_height or im_width != self.im_width:
            self.im_height = im_height
            self.im_width = im_width
            self.init_priorbox()

        # Per-channel offsets; presumably the training-time channel means in
        # the image's own channel order -- TODO confirm.
        img -= (105, 110, 110)
        img = img.transpose(2, 0, 1)

        loc, conf, landms = self._forward(img)

        # Decoded coordinates are scaled back to pixels with (w, h) pairs.
        box_scale = torch.Tensor([im_width, im_height] * 2).to(self.device)
        landm_scale = torch.Tensor([im_width, im_height] * 4).to(self.device)
        variance = self.cfg['variance']

        boxes = decode(loc.data.squeeze(0), self.prior_data, variance)
        boxes = (boxes * box_scale / resize).cpu().numpy()
        scores = conf.squeeze(0).data.cpu().numpy()[:, 1]
        landms = decode_landm(landms.data.squeeze(0), self.prior_data, variance)
        landms = (landms * landm_scale / resize).cpu().numpy()

        # ignore low scores
        inds = np.where(scores > args.confidence_threshold)[0]
        boxes, landms, scores = boxes[inds], landms[inds], scores[inds]

        # keep top-K before NMS
        order = scores.argsort()[::-1][:args.top_k]
        boxes, landms, scores = boxes[order], landms[order], scores[order]

        # do NMS
        dets = np.hstack((boxes, scores[:, np.newaxis])).astype(
            np.float32, copy=False)
        keep = py_cpu_nms(dets, args.nms_threshold)
        dets = dets[keep, :]
        landms = landms[keep]

        # keep top-K after NMS
        dets = dets[:args.keep_top_k, :]
        landms = landms[:args.keep_top_k, :]

        return np.concatenate((dets, landms), axis=1)
