#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Jul 12 09:37:49 2019

@author: xingyu
"""

"""
    Generate positive, negative, part and landmark crops, resized to 47x12 (width x height), to feed into PNet.
"""
import sys
sys.path.append('..')
import cv2
import random
import os
import numpy as np
from utils.util import*
from imutils import paths
import shutil


def _read_config(config_path="MTCNN_train_config.txt"):
    """Parse a ``key=value`` config file into a dict of strings.

    Only the text before the first comma on each line is used. Surrounding
    whitespace (including the trailing newline) is stripped from keys and
    values, and blank lines are skipped.

    Raises:
        ValueError: if a non-blank entry contains no ``=``.
    """
    config = {}
    with open(config_path) as f:
        for line in f:
            entry = line.split(',')[0].strip()
            if not entry:
                continue
            key, val = entry.split('=', 1)
            config[key.strip()] = val.strip()
    return config


def _resolve_root(root_path):
    """Return *root_path* truncated just after its first ``MTCNN`` component.

    Paths that do not already start with ``/notebooks`` are re-rooted under
    ``/notebooks``. If there is no ``MTCNN`` component, the whole (possibly
    re-rooted) path is returned.
    """
    path = '/' if root_path.startswith('/notebooks') else '/notebooks'
    for part in root_path.split('/'):
        path = os.path.join(path, part)
        if part == 'MTCNN':
            break
    return path


def _reset_dir(path):
    """Delete *path* (if present) and recreate it as an empty directory."""
    # rmtree(ignore_errors=True) is a no-op for a missing path. If the deletion
    # leaves the directory behind, makedirs raises rather than silently mixing
    # stale crops with freshly generated ones.
    shutil.rmtree(path, ignore_errors=True)
    os.makedirs(path)


def _parse_plate_name(im_path):
    """Parse the bounding box and four corner points encoded in an image name.

    The file stem is split on ``-``. Field 2 holds the box as ``x1&y1_x2&y2``.
    Field 3 holds four corner points as ``x&y_x&y_x&y_x&y``. This looks like
    the CCPD naming scheme; confirm.

    Returns:
        (boxes, gt): ``boxes`` is an int32 array of shape (1, 4) holding
        ``[x1, y1, x2, y2]``. ``gt`` is an int32 array of the 8 corner
        coordinates in file order.
    """
    stem = os.path.splitext(os.path.basename(im_path))[0]
    fields = stem.split('-')
    box_parts = fields[2].split('_')
    top_left = box_parts[0].split('&')
    bottom_right = box_parts[1].split('&')
    boxes = np.array([[int(top_left[0]), int(top_left[1]),
                       int(bottom_right[0]), int(bottom_right[1])]],
                     dtype=np.int32)
    corner_parts = fields[3].split('_')
    gt = np.array([int(corner_parts[i].split('&')[j])
                   for i in range(4) for j in (0, 1)],
                  dtype=np.int32)
    return boxes, gt


def prepare_data(root_path):
    """Generate 47x12 PNet training crops for the ``tra`` and ``val`` splits.

    Every image under ``<save_dir>/<split>`` is processed. ``save_dir`` is read
    from ``MTCNN_train_config.txt`` in the current working directory. Random
    crops are labelled by IoU against the box parsed from the file name:

    * 35 random crops anywhere in the image, plus up to 5 crops near each box,
      with max IoU < 0.3 -> negative (label ``0``).
    * Up to 20 crops jittered around each box:
      IoU >= 0.65 -> positive (label ``1``), also saved as a landmark sample
      (label ``-2``);
      0.4 <= IoU < 0.65 -> part (label ``-1``), capped at roughly 1.2x the
      number of positives produced so far.

    Crops are resized to 47x12 (width x height) and written to
    ``<root>/data_set/<split>/Pnet/{positive,part,negative,land}``. Each of
    these directories is wiped first. Annotation lines go to
    ``<root>/data_preprocessing/anno_store/{pos,neg,part,land}_Pnet_<split>.txt``.

    Args:
        root_path: a path inside the project checkout. It is truncated after
            its ``MTCNN`` component and re-rooted under ``/notebooks`` when it
            does not already start there.

    Returns:
        True once both splits have been processed.
    """
    config = _read_config()
    root_path = _resolve_root(root_path)
    anno_dir = os.path.join(root_path, 'data_preprocessing', 'anno_store')
    # The annotation files are opened for writing below; create their
    # directory up front instead of failing on a fresh checkout.
    os.makedirs(anno_dir, exist_ok=True)

    for data_set in ["tra", "val"]:
        img_dir = os.path.join(config['save_dir'], data_set)
        pnet_dir = os.path.join(root_path, "data_set", data_set, "Pnet")
        pos_save_dir = os.path.join(pnet_dir, "positive")
        part_save_dir = os.path.join(pnet_dir, "part")
        neg_save_dir = os.path.join(pnet_dir, "negative")
        land_save_dir = os.path.join(pnet_dir, "land")
        for out_dir in (pos_save_dir, part_save_dir, neg_save_dir, land_save_dir):
            _reset_dir(out_dir)

        # Label files for positive, negative, part and landmark crops. Using
        # `with` guarantees they are flushed and closed even if a crop fails.
        with open(os.path.join(anno_dir, 'pos_Pnet_%s.txt' % data_set), 'w') as f1, \
                open(os.path.join(anno_dir, 'neg_Pnet_%s.txt' % data_set), 'w') as f2, \
                open(os.path.join(anno_dir, 'part_Pnet_%s.txt' % data_set), 'w') as f3, \
                open(os.path.join(anno_dir, 'land_Pnet_%s.txt' % data_set), 'w') as f4:

            img_paths = list(paths.list_images(img_dir))
            random.shuffle(img_paths)
            print("%d pics in total" % len(img_paths))

            p_idx = 0  # positive
            n_idx = 0  # negative
            d_idx = 0  # part ("dont care")
            l_idx = 0  # landmark
            idx = 0    # images visited
            for im_path in img_paths:
                print(im_path)
                boxes, gt = _parse_plate_name(im_path)

                img = cv2.imread(im_path)
                idx += 1
                if img is None:
                    # cv2.imread returns None for missing or unreadable files.
                    print("failed to read %s, skipping" % im_path)
                    continue
                height, width = img.shape[:2]
                # Crop sizes are drawn with randint(47, min_side / 2), which
                # needs min_side / 2 >= 48 for a non-empty range.
                if min(width, height) < 96:
                    print("%s is too small (%dx%d), skipping" % (im_path, width, height))
                    continue
                half_min = min(width, height) / 2

                # Negatives: random crops anywhere in the image.
                neg_num = 0
                while neg_num < 35:
                    size_x = np.random.randint(47, half_min)
                    size_y = np.random.randint(12, half_min)
                    nx = np.random.randint(0, width - size_x)
                    ny = np.random.randint(0, height - size_y)
                    crop_box = np.array([nx, ny, nx + size_x, ny + size_y])

                    # IoU with every gt box must be below 0.3.
                    if np.max(IoU(crop_box, boxes)) < 0.3:
                        cropped_im = img[ny: ny + size_y, nx: nx + size_x, :]
                        resized_im = cv2.resize(cropped_im, (47, 12), interpolation=cv2.INTER_LINEAR)
                        save_file = os.path.join(neg_save_dir, "%s.jpg" % n_idx)
                        f2.write(save_file + ' 0\n')
                        cv2.imwrite(save_file, resized_im)
                        n_idx += 1
                        neg_num += 1

                for box in boxes:
                    # box is (x1, y1, x2, y2); the +1 treats x2/y2 as inclusive.
                    x1, y1, x2, y2 = box
                    w = x2 - x1 + 1
                    h = y2 - y1 + 1

                    # Negatives that start near the gt box (hard negatives).
                    for _ in range(5):
                        size_x = np.random.randint(47, half_min)
                        size_y = np.random.randint(12, half_min)
                        # delta_x and delta_y are offsets of (x1, y1).
                        delta_x = np.random.randint(max(-size_x, -x1), w)
                        delta_y = np.random.randint(max(-size_y, -y1), h)
                        nx1 = max(0, x1 + delta_x)
                        ny1 = max(0, y1 + delta_y)

                        if nx1 + size_x > width or ny1 + size_y > height:
                            continue
                        crop_box = np.array([nx1, ny1, nx1 + size_x, ny1 + size_y])

                        if np.max(IoU(crop_box, boxes)) < 0.3:
                            cropped_im = img[ny1: ny1 + size_y, nx1: nx1 + size_x, :]
                            resized_im = cv2.resize(cropped_im, (47, 12), interpolation=cv2.INTER_LINEAR)
                            save_file = os.path.join(neg_save_dir, "%s.jpg" % n_idx)
                            f2.write(save_file + ' 0\n')
                            cv2.imwrite(save_file, resized_im)
                            n_idx += 1

                    # Positive and part crops jittered around the gt box.
                    for _ in range(20):
                        size_x = np.random.randint(int(min(w, h) * 0.8), np.ceil(1.25 * max(w, h)))
                        size_y = np.random.randint(int(min(w, h) * 0.8), np.ceil(1.25 * max(w, h)))

                        # Delta here is the offset of the box centre.
                        delta_x = np.random.randint(-w * 0.2, w * 0.2)
                        delta_y = np.random.randint(-h * 0.2, h * 0.2)

                        nx1 = max(x1 + w / 2 + delta_x - size_x / 2, 0)
                        ny1 = max(y1 + h / 2 + delta_y - size_y / 2, 0)
                        nx2 = nx1 + size_x
                        ny2 = ny1 + size_y

                        if nx2 > width or ny2 > height:
                            continue
                        crop_box = np.array([nx1, ny1, nx2, ny2])

                        iou = IoU(crop_box, box.reshape(1, -1))
                        is_positive = bool(iou >= 0.65)
                        # Part crops are capped at roughly 1.2x the positives
                        # produced so far, keeping the two sets balanced.
                        is_part = (not is_positive and bool(iou >= 0.4)
                                   and d_idx < 1.2 * p_idx + 1)
                        if not (is_positive or is_part):
                            continue

                        # Box-regression targets, normalised by the crop size.
                        offset_x1 = (x1 - nx1) / float(size_x)
                        offset_y1 = (y1 - ny1) / float(size_y)
                        offset_x2 = (x2 - nx2) / float(size_x)
                        offset_y2 = (y2 - ny2) / float(size_y)
                        box_offsets = '%.2f %.2f %.2f %.2f' % (offset_x1, offset_y1, offset_x2, offset_y2)

                        cropped_im = img[int(ny1): int(ny2), int(nx1): int(nx2), :]
                        resized_im = cv2.resize(cropped_im, (47, 12), interpolation=cv2.INTER_LINEAR)

                        if is_part:
                            save_file = os.path.join(part_save_dir, "%s.jpg" % d_idx)
                            f3.write(save_file + ' -1 ' + box_offsets + '\n')
                            cv2.imwrite(save_file, resized_im)
                            d_idx += 1
                            continue

                        save_file = os.path.join(pos_save_dir, "%s.jpg" % p_idx)
                        f1.write(save_file + ' 1 ' + box_offsets + '\n')
                        cv2.imwrite(save_file, resized_im)
                        p_idx += 1

                        # Landmark targets: the four corners (ordered by
                        # data_sorted) relative to the crop origin, with even
                        # indices being x and odd indices being y.
                        corners = data_sorted(gt)
                        landmark = [
                            (c - nx1) / float(size_x) if k % 2 == 0 else (c - ny1) / float(size_y)
                            for k, c in enumerate(corners)
                        ]
                        save_file = os.path.join(land_save_dir, "%s.jpg" % l_idx)
                        f4.write(save_file + ' -2 ' + box_offsets + ' '
                                 + ' '.join('%.2f' % v for v in landmark) + '\n')
                        cv2.imwrite(save_file, resized_im)
                        l_idx += 1

                print("%s images done, pos: %s part: %s neg: %s landmark %s" % (idx, p_idx, d_idx, n_idx, l_idx))
    return True

def data_sorted(gt):
    """Reorder four (x, y) points into a fixed left/right, y-ascending order.

    *gt* must unpack to exactly eight numbers ``x1, y1, x2, y2, x3, y3, x4, y4``;
    otherwise a ``ValueError`` is raised. The points are sorted by x. The two
    with the smallest x are then sorted by y, followed by the remaining two,
    also sorted by y. Python's sort is stable, so ties keep their input order.

    Returns:
        A flat list of the eight coordinates in that order.
    """
    x1, y1, x2, y2, x3, y3, x4, y4 = gt
    by_x = sorted([(x1, y1), (x2, y2), (x3, y3), (x4, y4)], key=lambda pt: pt[0])
    left_side = sorted(by_x[:2], key=lambda pt: pt[1])
    right_side = sorted(by_x[2:], key=lambda pt: pt[1])
    return [coord for point in left_side + right_side for coord in point]
    

    