"""
    Generate positive, negative, part and landmark crops (resized to 94x24) from PNet detections and save them as ONet training data.
"""
import sys
sys.path.append('..')
import cv2
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import numpy as np
from utils.util import*
import torch
import random
from imutils import paths
from MTCNN import MTCNN_main
import shutil

def _read_config(config_path):
    """Parse ``key=value`` entries from the training config file.

    Only the text before the first ',' on each line is considered. Lines
    without '=' (e.g. blank lines) are skipped, and keys/values are stripped
    so a trailing newline never ends up inside a path.

    Args:
        config_path: path of the config file to read.

    Returns:
        dict mapping config keys to their string values.
    """
    config = {}
    with open(config_path) as f:
        for line in f:
            entry = line.split(',')[0]
            if '=' not in entry:
                continue
            key, val = entry.split('=', 1)
            config[key.strip()] = val.strip()
    return config


def _resolve_mtcnn_root(root_path):
    """Rebase ``root_path`` onto the project root that ends at 'MTCNN'.

    Path components are re-joined one by one up to and including the first
    component named 'MTCNN'. Paths not already under '/notebooks' are
    re-rooted there. If no 'MTCNN' component exists, the whole path is kept.
    """
    path = '/' if root_path.startswith('/notebooks') else '/notebooks'
    for part in root_path.split('/'):
        path = os.path.join(path, part)
        if part == 'MTCNN':
            break
    return path


def _reset_dir(path):
    """Delete ``path`` (if it exists) and recreate it as an empty directory."""
    shutil.rmtree(path, ignore_errors=True)
    # exist_ok: rmtree errors are ignored, so the directory may still exist.
    os.makedirs(path, exist_ok=True)


def _parse_plate_name(im_path):
    """Extract the ground-truth box and four corner points from a filename.

    The filename stem is split on '-'. Field 2 holds the box as
    ``x1&y1_x2&y2`` and field 3 holds four points as ``x&y_x&y_x&y_x&y``.
    This looks like the CCPD naming scheme; confirm against the dataset.

    Returns:
        (boxes, gt): ``boxes`` is an int32 array of shape (1, 4) holding
        x1, y1, x2, y2. ``gt`` is an int32 array of the 8 point coordinates
        in filename order.
    """
    fields = os.path.splitext(os.path.basename(im_path))[0].split('-')

    box_fields = fields[2].split('_')
    top_left = box_fields[0].split('&')
    bottom_right = box_fields[1].split('&')
    boxes = np.array(
        [[int(top_left[0]), int(top_left[1]),
          int(bottom_right[0]), int(bottom_right[1])]],
        dtype=np.int32)

    corners = fields[3].split('_')
    gt = np.array(
        [int(corners[i].split('&')[j]) for i in range(4) for j in range(2)],
        dtype=np.int32)
    return boxes, gt


def _generate_split(mtcnn_net, root_path, save_root, data_set):
    """Generate ONet crops and annotation lists for one dataset split.

    Every image under ``<save_root>/<data_set>`` is run through
    ``mtcnn_net.detect_plate``. Each in-bounds detection is cropped, resized
    and saved as:

    - negative: IoU < 0.3, capped at about 3.2x the positive count;
    - positive: IoU >= 0.65, also saved as a landmark sample;
    - part: IoU >= 0.4, capped at about 1.2x the positive count.

    Args:
        mtcnn_net: detector whose ``detect_plate`` result is indexed as
            ``bboxes[:, 0:4]`` (box corners; any further columns are
            unused here).
        root_path: resolved project root (output and annotation base).
        save_root: directory that holds the raw per-split image folders.
        data_set: split name, e.g. "tra" or "val".
    """
    img_dir = os.path.join(save_root, data_set)
    print(img_dir)

    onet_dir = os.path.join(root_path, "data_set", data_set, "Onet")
    pos_save_dir = os.path.join(onet_dir, "positive")
    part_save_dir = os.path.join(onet_dir, "part")
    neg_save_dir = os.path.join(onet_dir, "negative")
    land_save_dir = os.path.join(onet_dir, "land")
    for out_dir in (pos_save_dir, part_save_dir, neg_save_dir, land_save_dir):
        _reset_dir(out_dir)

    anno_dir = os.path.join(root_path, 'data_preprocessing', 'anno_store')
    os.makedirs(anno_dir, exist_ok=True)
    pos_anno = os.path.join(anno_dir, 'pos_Onet_%s.txt' % (data_set))
    neg_anno = os.path.join(anno_dir, 'neg_Onet_%s.txt' % (data_set))
    part_anno = os.path.join(anno_dir, 'part_Onet_%s.txt' % (data_set))
    # NOTE(review): unlike the other ONet lists this file is named
    # "land_Pnet_*". It is kept as-is because readers elsewhere may rely on it.
    land_anno = os.path.join(anno_dir, 'land_Pnet_%s.txt' % (data_set))

    img_paths = list(paths.list_images(img_dir))
    random.shuffle(img_paths)
    print("%d pics in total" % len(img_paths))

    # cv2.resize takes dsize as (width, height): crops are 94 wide, 24 tall.
    image_size = (94, 24)
    p_idx = 0  # positive
    n_idx = 0  # negative
    d_idx = 0  # part ("don't care")
    l_idx = 0  # landmark
    idx = 0    # images that produced at least one detection

    # A single `with` guarantees that all four lists, including the landmark
    # list, are flushed and closed even if processing raises.
    with open(pos_anno, 'w') as f1, open(neg_anno, 'w') as f2, \
            open(part_anno, 'w') as f3, open(land_anno, 'w') as f4:
        for im_path in img_paths:
            print(im_path)
            boxes, gt = _parse_plate_name(im_path)
            # Corner points ordered by data_sorted; used for positive crops.
            landmarks = data_sorted(gt)

            image = cv2.imread(im_path)
            # imread returns None for unreadable files; crops below also
            # index a channel axis, so require an H x W x C array.
            if image is None or image.ndim != 3:
                continue
            img_h, img_w = image.shape[:2]

            bboxes = mtcnn_net.detect_plate(image, [100, 100], True)
            if bboxes is None or len(bboxes) == 0:
                continue
            dets = np.round(bboxes[:, 0:4])
            idx += 1

            for box in dets:
                x_left, y_top, x_right, y_bottom = box[0:4].astype(int)
                box_w = x_right - x_left + 1
                box_h = y_bottom - y_top + 1

                # Ignore boxes that are too narrow or cross the image border.
                if (box_w < 20 or x_left < 0 or y_top < 0
                        or x_right > img_w - 1 or y_bottom > img_h - 1):
                    continue

                # IoU between this detection and the ground-truth box(es).
                iou = IoU(box, boxes)
                max_iou = np.max(iou)
                cropped_im = image[y_top:y_bottom + 1, x_left:x_right + 1, :]
                if cropped_im.size == 0:
                    continue
                resized_im = cv2.resize(cropped_im, image_size,
                                        interpolation=cv2.INTER_LINEAR)

                # Negative sample, throttled relative to the positive count.
                if max_iou < 0.3 and n_idx < 3.2 * p_idx + 1:
                    save_file = os.path.join(neg_save_dir, "%s.jpg" % n_idx)
                    f2.write(save_file + ' 0\n')
                    cv2.imwrite(save_file, resized_im)
                    n_idx += 1
                    continue

                x1, y1, x2, y2 = boxes[np.argmax(iou)]
                # Box-regression targets, normalized by the crop size.
                offsets = (
                    (x1 - x_left) / float(box_w),
                    (y1 - y_top) / float(box_h),
                    (x2 - x_right) / float(box_w),
                    (y2 - y_bottom) / float(box_h),
                )

                if max_iou >= 0.65:
                    save_file = os.path.join(pos_save_dir, "%s.jpg" % p_idx)
                    f1.write(save_file + ' 1 %.2f %.2f %.2f %.2f\n' % offsets)
                    cv2.imwrite(save_file, resized_im)
                    p_idx += 1

                    # Landmark targets: each corner point relative to the
                    # crop's top-left, normalized by the crop size.
                    land_offsets = []
                    for px, py in zip(landmarks[0::2], landmarks[1::2]):
                        land_offsets.append((px - x_left) / float(box_w))
                        land_offsets.append((py - y_top) / float(box_h))
                    save_file = os.path.join(land_save_dir, "%s.jpg" % l_idx)
                    f4.write(save_file + ' -2 %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f\n'
                             % (offsets + tuple(land_offsets)))
                    cv2.imwrite(save_file, resized_im)
                    l_idx += 1
                elif max_iou >= 0.4 and d_idx < 1.2 * p_idx + 1:
                    save_file = os.path.join(part_save_dir, "%s.jpg" % d_idx)
                    f3.write(save_file + ' -1 %.2f %.2f %.2f %.2f\n' % offsets)
                    cv2.imwrite(save_file, resized_im)
                    d_idx += 1

            print("%s images done, pos: %s part: %s neg: %s landmark: %s" % (idx, p_idx, d_idx, n_idx, l_idx))


def prepare_data(root_path):
    """Build ONet training crops for the "tra" and "val" splits.

    Reads ``MTCNN_train_config.txt`` from the current working directory. Its
    ``save_dir`` entry is the folder holding the per-split raw images.
    ``root_path`` is truncated at its 'MTCNN' component and rebased under
    '/notebooks' if needed. The PNet weights are loaded from
    ``<root>/train/pnet_Weights``. For each split, the output directories
    ``<root>/data_set/<split>/Onet/{positive,part,negative,land}`` are wiped
    and regenerated, and annotation lists are written to
    ``<root>/data_preprocessing/anno_store``.

    Args:
        root_path: any path inside the MTCNN project tree.

    Returns:
        True once both splits have been processed.
    """
    config = _read_config("MTCNN_train_config.txt")
    save_root = config['save_dir']
    root_path = _resolve_mtcnn_root(root_path)

    # Build the detector once and reuse it for both splits.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    p_model_path = os.path.join(root_path, 'train', 'pnet_Weights')
    mtcnn_net = MTCNN_main(device, p_model_path, None)
    # A torch.device never compares equal to a plain str; check .type instead.
    if device.type == "cpu":
        print("Training with CPU")
    else:
        print("Training with GPU")

    for data_set in ["tra", "val"]:
        _generate_split(mtcnn_net, root_path, save_root, data_set)

    del mtcnn_net
    torch.cuda.empty_cache()
    return True

def data_sorted(gt):
    p1x,p1y,p2x,p2y,p3x,p3y,p4x,p4y=gt
    p1=(p1x,p1y)
    p2=(p2x,p2y)
    p3=(p3x,p3y)
    p4=(p4x,p4y)
    l=[]
    l.append(p1)
    l.append(p2)
    l.append(p3)
    l.append(p4)
    l=sorted(l,key=(lambda x:x[0]))
    a=sorted(l[:2],key=(lambda x:x[1]))
    b=sorted(l[2:4],key=(lambda x:x[1]))
    r=a+b
    rr=[]
    for i in r:
        for j in i:
            rr.append(j)
    return rr




