import torch
import numpy as np
from .data import cfg_mnet, cfg_re50
from layers.functions.prior_box import PriorBox
from utils.nms.py_cpu_nms import py_cpu_nms
from utils.box_utils import decode, decode_landm
from Pytorch_Retinaface import args

def decode_box(loc, conf, landms, prior_data=None, cfg=None, device=None,
               *, im_height=None, im_width=None, resize=1):
    """Turn raw detector outputs into filtered, NMS-ed detections.

    The raw ``loc`` / ``conf`` / ``landms`` arrays are decoded against the
    prior boxes, rescaled to pixel coordinates, filtered by score, reduced
    with NMS and finally truncated to ``args.keep_top_k`` rows.

    Args:
        loc: NumPy array of box regressions. A leading axis of size 1 is
            squeezed away before decoding.
        conf: NumPy array of class scores; column 1 is used as the
            detection score after squeezing the leading axis.
        landms: NumPy array of landmark regressions, squeezed like ``loc``.
        prior_data: Prior boxes to decode against. When ``None`` they are
            generated with ``PriorBox`` for ``(im_height, im_width)``.
        cfg: Model configuration providing at least ``cfg['variance']``.
            When ``None`` it is picked from ``args.network``.
        device: Torch device for the intermediate tensors. When ``None``
            it is chosen from ``args.cpu``.
        im_height: Height in pixels of the image fed to the network.
        im_width: Width in pixels of the image fed to the network.
        resize: Factor by which the image was resized before inference;
            decoded coordinates are divided by it.

    Returns:
        A NumPy array with one row per kept detection: the box columns,
        the score column, then the landmark columns.

    Raises:
        ValueError: If the image size is missing or no configuration can
            be resolved for ``args.network``.
    """
    # The image size is required both for generating priors and for mapping
    # the decoded coordinates to pixels, so fail early with a clear message
    # instead of a NameError deep inside the function.
    if im_height is None or im_width is None:
        raise ValueError(
            "decode_box needs im_height and im_width to rescale detections"
        )

    if device is None:
        device = torch.device("cpu" if args.cpu else "cuda")

    if cfg is None:
        if args.network == "mobile0.25":
            cfg = cfg_mnet
        elif args.network == "resnet50":
            # NOTE(review): "resnet50" is assumed to be the network name
            # that pairs with cfg_re50 -- confirm against the args parser.
            cfg = cfg_re50
        else:
            raise ValueError(
                "no default cfg for network %r; pass cfg explicitly"
                % (args.network,)
            )

    # `is None` rather than `== None`: equality on an array/tensor argument
    # is element-wise or ambiguous, identity is always a plain bool.
    if prior_data is None:
        priorbox = PriorBox(cfg, image_size=(im_height, im_width))
        priors = priorbox.forward()
        priors = priors.to(device)
        prior_data = priors.data

    # Convert the NumPy inputs to torch tensors on the target device.
    loc = torch.from_numpy(loc).to(device)
    conf = torch.from_numpy(conf).to(device)
    landms = torch.from_numpy(landms).to(device)

    # (width, height) pair used to build the per-column scale vectors.
    # NOTE(review): assumes decoded coordinates are normalised x, y values
    # in alternating columns -- confirm against decode / decode_landm.
    wh = torch.tensor([im_width, im_height], dtype=torch.float32,
                      device=device)

    boxes = decode(loc.squeeze(0), prior_data, cfg['variance'])
    boxes = boxes * wh.repeat(2) / resize
    boxes = boxes.cpu().numpy()

    scores = conf.squeeze(0).cpu().numpy()[:, 1]

    landms = decode_landm(landms.squeeze(0), prior_data, cfg['variance'])
    # Size the landmark scale from the decoded tensor instead of a
    # hard-coded list, so it always matches the number of coordinates.
    landm_scale = wh.repeat(landms.shape[-1] // 2)
    landms = landms * landm_scale / resize
    landms = landms.cpu().numpy()

    # Ignore low scores.
    inds = np.where(scores > args.confidence_threshold)[0]
    boxes = boxes[inds]
    landms = landms[inds]
    scores = scores[inds]

    # Keep the top-K highest scoring candidates before NMS.
    order = scores.argsort()[::-1][:args.top_k]
    boxes = boxes[order]
    landms = landms[order]
    scores = scores[order]

    # Run NMS on [box, score] rows.
    dets = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32,
                                                            copy=False)
    keep = py_cpu_nms(dets, args.nms_threshold)
    dets = dets[keep, :]
    landms = landms[keep]

    # Keep at most keep_top_k detections after NMS.
    dets = dets[:args.keep_top_k, :]
    landms = landms[:args.keep_top_k, :]

    return np.concatenate((dets, landms), axis=1)
