import os
import os.path
import sys
import torch
import torch.utils.data as data
import cv2
import numpy as np

class WiderFaceDetection(data.Dataset):
    """Detection dataset whose annotations are encoded in the image file names.

    Every file found (recursively) under ``training_dataset`` is one sample.
    The part of the file name before the first ``'.'`` is split on ``'-'`` and
    these fields are read:

    * field 1 -- sample label; ``"pos_sam"`` marks a positive sample,
    * field 2 -- bounding box as ``"x1&y1_x2&y2"``,
    * field 3 -- four corner points as ``"x&y_x&y_x&y_x&y"``.

    Each sample carries exactly one 13-value annotation tuple: the box, the
    four corner points re-ordered by ``landm_sorted`` (all ``-1`` for samples
    whose label is not ``"pos_sam"``), and a 1/0 positive flag.

    Args:
        training_dataset: Root directory walked with ``os.walk``. Every file
            in it must follow the naming scheme above, otherwise parsing
            raises (``IndexError``/``ValueError``).
        preproc: Optional callable ``(img, target) -> (img, target)`` applied
            in ``__getitem__``.
    """

    def __init__(self, training_dataset, preproc=None):
        self.preproc = preproc
        self.imgs_path = []  # one image path per sample
        self.words = []      # per sample: a one-element list of annotation tuples
        for root, _folders, files in os.walk(training_dataset):
            for name in files:
                img_path = os.path.join(root, name)
                self.words.append([self._parse_annotation(name)])
                self.imgs_path.append(img_path)

    @staticmethod
    def _parse_point(token):
        """Parse an ``"x&y"`` token into an ``(int, int)`` pair."""
        coords = token.split('&')
        return int(coords[0]), int(coords[1])

    @classmethod
    def _parse_annotation(cls, filename):
        """Build the 13-value annotation tuple encoded in ``filename``."""
        fields = filename.split('.')[0].split('-')
        label = fields[1]

        # Bounding box: two "x&y" points joined by '_'.
        box_p1, box_p2 = fields[2].split('_')
        x1, y1 = cls._parse_point(box_p1)
        x2, y2 = cls._parse_point(box_p2)

        # Four corner points; parsed (and therefore validated) for every
        # sample, even though only positive samples keep them.
        c0, c1, c2, c3 = fields[3].split('_')
        landm = []
        for token in (c0, c1, c2, c3):
            landm.extend(cls._parse_point(token))

        if label == "pos_sam":
            # Positive sample: keep the corner points in canonical order.
            return (x1, y1, x2, y2, *landm_sorted(landm), 1)
        # Any other label: corner points are replaced by -1 placeholders.
        return (x1, y1, x2, y2) + (-1,) * 8 + (0,)

    def __len__(self):
        return len(self.imgs_path)

    @staticmethod
    def _labels_to_annotations(labels):
        """Convert annotation tuples into an ``(N, 13)`` float64 array.

        Columns 0-11 are copied from each tuple. Column 12 is recomputed:
        ``-1`` when the first landmark x-coordinate (column 4) is negative,
        i.e. the sample has no landmarks, and ``1`` otherwise. The flag stored
        at index 12 of each tuple is therefore not used.
        """
        annotations = np.zeros((len(labels), 13))
        for row, label in enumerate(labels):
            annotations[row, :12] = label[:12]
        annotations[:, 12] = np.where(annotations[:, 4] < 0, -1, 1)
        return annotations

    def __getitem__(self, index):
        img_path = self.imgs_path[index]
        # To train on black-and-white and colour images together, the image
        # could instead be loaded as grayscale with cv2.imread(img_path, 0)
        # and converted back with cv2.cvtColor(img, cv2.COLOR_GRAY2BGR).
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals a missing/unreadable file by returning None
            # rather than raising; fail loudly with the offending path.
            raise IOError(f"failed to read image: {img_path}")

        # An empty label list yields a (0, 13) target and still goes through
        # the normal path, so every sample is an (image, target) pair.
        target = self._labels_to_annotations(self.words[index])
        if self.preproc is not None:
            img, target = self.preproc(img, target)

        return torch.from_numpy(img), target

def detection_collate(batch):
    """Collate samples whose images have different numbers of annotations.

    Every element of every sample in ``batch`` is inspected: tensors are
    collected as images, NumPy arrays are converted to float32 tensors and
    collected as annotations, anything else is ignored.

    Arguments:
        batch: (iterable) samples, typically ``(image_tensor, annotations)``
            pairs as produced by the dataset's ``__getitem__``.

    Return:
        A tuple containing:
            1) (tensor) the images stacked along a new dim 0
            2) (list of tensors) one float annotation tensor per image
    """
    images, annotations = [], []
    for item in (elem for sample in batch for elem in sample):
        if torch.is_tensor(item):
            images.append(item)
        elif isinstance(item, np.ndarray):
            annotations.append(torch.from_numpy(item).float())
    return torch.stack(images, 0), annotations

def landm_sorted(gt):
    """Return four (x, y) points in a canonical, flattened order.

    ``gt`` must be a flat sequence of exactly eight coordinates
    ``[x0, y0, x1, y1, x2, y2, x3, y3]`` (any other length raises
    ``ValueError``). The two points with the smallest x form the first pair
    and the remaining two form the second pair; within each pair the points
    are ordered by ascending y. Sorting is stable, so ties keep input order.

    Returns:
        list: the eight coordinates, first pair then second pair, each point
        flattened as ``x, y``.
    """
    x0, y0, x1, y1, x2, y2, x3, y3 = gt
    by_x = sorted([(x0, y0), (x1, y1), (x2, y2), (x3, y3)], key=lambda pt: pt[0])
    first_pair = sorted(by_x[:2], key=lambda pt: pt[1])
    second_pair = sorted(by_x[2:], key=lambda pt: pt[1])
    return [coord for pt in first_pair + second_pair for coord in pt]
