from xml.dom import minidom 
import traceback
import cv2
import os
import numpy as np
import argparse
import shutil
import random

def readXml(xml_path,save_dir,imgs_folder,count,resize=False):
    try:
        doc=minidom.parse(xml_path)
        size=doc.getElementsByTagName("size")[0]
        max_x=int(size.getElementsByTagName("width")[0].firstChild.data)
        max_y=int(size.getElementsByTagName("height")[0].firstChild.data)
        
        img_name=doc.getElementsByTagName("filename")[0].firstChild.data            
        img_path=os.path.join(imgs_folder,img_name)
        img=cv2.imread(img_path)
        shape=np.shape(img)
        try:
            if shape[0]>=2000 or shape[1]>=2000:
                resize=True
            if shape[0]<=200 or shape[1]<=200:
                return False
        except:
            return False
        if resize:
            img=cv2.resize(img,None, fx=0.5, fy=0.5, interpolation=cv2.INTER_LINEAR)
        
        objs=doc.getElementsByTagName("object")
        for n,obj in enumerate(objs):
            x1=None
            y1=None
            x2=None
            y2=None
            x1=(int)(obj.getElementsByTagName("xmin")[0].firstChild.data)
            if x1<0:return False
            if x1>max_x:return False
            y1=(int)(obj.getElementsByTagName("ymin")[0].firstChild.data)
            if y1<0:return False
            if y1>max_y:return False
            x2=(int)(obj.getElementsByTagName("xmax")[0].firstChild.data)
            if x2<0:return False
            if x2>max_x:return False            
            y2=(int)(obj.getElementsByTagName("ymax")[0].firstChild.data)
            if y2<0:return False
            if y2>max_y:return False
            if resize:
                x1//=2
                y1//=2
                x2//=2
                y2//=2
    #         print(x1,y1,x2,y2)
            p1=None
            p2=None
            p3=None
            p4=None        
            lands=obj.getElementsByTagName("landmark")
            for land in lands:
                p1=list(map(int,land.getElementsByTagName("p1")[0].firstChild.data.split(",")))
                p2=list(map(int,land.getElementsByTagName("p2")[0].firstChild.data.split(",")))
                p3=list(map(int,land.getElementsByTagName("p3")[0].firstChild.data.split(",")))
                p4=list(map(int,land.getElementsByTagName("p4")[0].firstChild.data.split(",")))
                for j in range(0,2):
                    if p1[j] < 0:return False                    
                    if p2[j] < 0:return False
                    if p3[j] < 0:return False
                    if p4[j] < 0:return False
                    if resize:
                        p1[j]//=2
                        p2[j]//=2
                        p3[j]//=2
                        p4[j]//=2
                if p1[0] > max_x:return False
                if p2[0] > max_x:return False 
                if p3[0] > max_x:return False 
                if p4[0] > max_x:return False
                if p1[1] > max_y:return False
                if p2[1] > max_y:return False 
                if p3[1] > max_y:return False 
                if p4[1] > max_y:return False
                
            
            if (x1==None or y1==None or x2==None or y2==None):
                return False
            if (p1==None or p2==None or p3==None or p4==None):
                return False
            if (len(p1)<2 or len(p2)<2 or len(p3)<2 or len(p4)<2):
                return False                
                
            fielname="{0}_{13}-pos_sam-{1}&{2}_{3}&{4}-{5}&{6}_{7}&{8}_{9}&{10}_{11}&{12}.jpg".format(str(count),x1,y1,x2,y2,p1[0],p1[1],p2[0],p2[1],p3[0],p3[1],p4[0],p4[1],n)
            print(img_path)
            try:
                if shape[0]>0 and shape[1]>0 and shape[2]>0:
                    cv2.imwrite(os.path.join(save_dir,fielname),img)
            except Exception as e:
                print(shape)
#                 print(e)
                return False
        return True        
    except Exception as e:
        print("---Annotation data not completed---")
        print("---Fail_path:",xml_path)
        raise(e)        
        return False      
        
        
def process(config_path="MTCNN_train_config.txt"):
    """Export every XML annotation under the configured directory tree.

    The config file is read line by line. Each non-blank line has the form
    ``key=value[,anything]``: only the text before the first comma is used,
    and it is split on the first ``=``. The keys ``xml_path``, ``save_dir`` and
    ``imgs_folder`` are required.

    Args:
        config_path: Path of the config file (defaults to the original
            hard-coded name).

    Returns:
        The number of XML files for which ``readXml`` returned a truthy value.
    """
    d = {}
    count = 0
    with open(config_path) as f:
        for line in f:
            entry = line.split(',')[0]
            if not entry.strip():
                # Skip blank lines (e.g. a trailing newline at end of file).
                continue
            key, val = entry.split('=', 1)
            # Strip so a value without a trailing comma does not keep its newline.
            key, val = key.strip(), val.strip()
            print([key, val])
            d[key] = val
    if not os.path.isdir(d['save_dir']):
        os.makedirs(d['save_dir'])
        print("create new image folder for MTCNN:", d['save_dir'])

    for root, _dirnames, filenames in os.walk(d['xml_path']):
        for filename in filenames:
            if filename.endswith(".xml"):
                # Join with the walked directory so files in subfolders resolve correctly.
                if readXml(os.path.join(root, filename), d['save_dir'], d['imgs_folder'], count):
                    count += 1

    print(d['xml_path'])
    print(d['save_dir'])
    print(d['imgs_folder'])
    return count
            
