from tool.utils import *
from tool.torch_utils import *
from tool.darknet2pytorch import Darknet
import cv2

"""hyper parameters"""
use_cuda = True
# _BASE_DIR = os.path.dirname(os.path.abspath("./"))
_BASE_DIR='./'
cfgfile = os.path.join(_BASE_DIR,"cfg/yolov4.cfg")
weightfile= os.path.join(_BASE_DIR,"weights","yolov4.weights")

model=None
params=["img","image_path"]

def model_init():
    """Build the Darknet network and store it in the module-level ``model``.

    The network definition is read from ``cfgfile`` and the weights from
    ``weightfile``. The global is assigned only after ``load_weights`` has
    returned. If loading raises, ``model`` keeps its previous value, so the
    next ``detect_car`` call retries initialisation instead of silently
    running a network with no weights loaded.
    """
    global model
    net = Darknet(cfgfile)
    net.load_weights(weightfile)
    # Publish only a fully initialised network.
    model = net

def detect_car(**kwargs):
    """Run the YOLOv4 detector on one image and plot the detections.

    Exactly one of the keywords listed in ``params`` must be supplied:

    * ``img``        -- an already-loaded OpenCV image. It is converted
      BGR->RGB before inference, so it is presumably BGR as returned by
      ``cv2.imread``.
    * ``image_path`` -- a path that is loaded with ``cv2.imread``.

    On the first call the network is built via ``model_init`` and, when
    ``use_cuda`` is set, moved to the GPU. The boxes for the image are passed
    to ``plot_boxes_cv2`` with ``savename='predictions.jpg'``, which
    presumably writes the annotated image there. Nothing is returned.

    Raises:
        TypeError: an unknown keyword was passed, or not exactly one of
            ``img`` / ``image_path`` was given.
        OSError: ``image_path`` could not be decoded by ``cv2.imread``.
    """
    for key in kwargs:
        if key not in params:
            # Raising a bare string is itself a TypeError in Python 3 and
            # loses the message; raise a real exception instead.
            raise TypeError("No Such Argument : %s" % key)
    image = kwargs.get(params[0])
    image_path = kwargs.get(params[1])

    # Use `is None`: `==` on a numpy image is element-wise and its truth
    # value is ambiguous. Exactly one of the two sources must be given.
    if (image is None) == (image_path is None):
        raise TypeError("Please input either an img or image_path")
    if image is None:
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise OSError("Could not read image: %s" % image_path)

    if model is None:
        model_init()
        if use_cuda:
            model.cuda()

    # Choose the class-name file that matches the network's class count.
    namesfile = {20: 'data/voc.names', 80: 'data/coco.names'}.get(
        model.num_classes, 'data/custom.names')
    class_names = load_class_names(os.path.join(_BASE_DIR, namesfile))

    # Resize to the network's input size and convert OpenCV's BGR to RGB.
    sized = cv2.resize(image, (model.width, model.height))
    sized = cv2.cvtColor(sized, cv2.COLOR_BGR2RGB)
    # NOTE(review): 0.4 / 0.6 are presumably do_detect's confidence and NMS
    # thresholds -- confirm against tool.torch_utils.do_detect.
    boxes = do_detect(model, sized, 0.4, 0.6, use_cuda)

    # Plot on the original-resolution image, not the resized copy.
    plot_boxes_cv2(image, boxes[0], savename='predictions.jpg', class_names=class_names)

#     width = image.shape[1]
#     height = image.shape[0]
#     boxes=boxes[0]
#     results=[]
#     for i in range(len(boxes)):
#         box = boxes[i]
#         x1 = int(box[0] * width)
#         y1 = int(box[1] * height)
#         x2 = int(box[2] * width)
#         y2 = int(box[3] * height)
        
#         if len(box) >= 7 and class_names:
#             cls_conf = box[5]
#             cls_id = box[6]
#             if cls_id in [2,3,5,7]:
#                 car=image[y1:y2,x1:x2]
#                 results.append((car,{'x':x1,'y':y1,'w':x2-x1,'h':y2-y1,'type':class_names[cls_id]},cls_conf))
#         else:
#             results.append((None))
#     return results