from __future__ import print_function
import sys
import cv2
from random import randint
from threading import Thread
import time
import datetime
import requests
import numpy as np


# Tracker names accepted by createTrackerByName (and by Muti_tracker's
# trackers_type argument).
trackerTypes = [
    'BOOSTING',
    'MIL',
    'KCF',
    'TLD',
    'MEDIANFLOW',
    'GOTURN',
    'MOSSE',
    'CSRT',
]

def createTrackerByName(trackerType):
    '''
    Create an OpenCV tracker for one of the names in trackerTypes.

    Returns the new tracker object, or None (after printing the list of
    valid names) when trackerType is not recognised.

    Newer OpenCV releases (4.5.1+) expose some of these trackers (e.g.
    BOOSTING, TLD, MEDIANFLOW, MOSSE) only under the cv2.legacy namespace,
    so each factory is looked up on cv2 first and on cv2.legacy as a
    fallback. Raises AttributeError if the installed OpenCV has neither.
    '''
    # Tracker name -> name of the OpenCV factory function that builds it.
    factory_names = {
        'BOOSTING': 'TrackerBoosting_create',
        'MIL': 'TrackerMIL_create',
        'KCF': 'TrackerKCF_create',
        'TLD': 'TrackerTLD_create',
        'MEDIANFLOW': 'TrackerMedianFlow_create',
        'GOTURN': 'TrackerGOTURN_create',
        'MOSSE': 'TrackerMOSSE_create',
        'CSRT': 'TrackerCSRT_create',
    }
    factory_name = factory_names.get(trackerType)
    if factory_name is None:
        print('Incorrect tracker name')
        print('Available trackers are:')
        for t in trackerTypes:
            print(t)
        return None
    # getattr(None, ..., None) is None, so a build without cv2.legacy is fine.
    for namespace in (cv2, getattr(cv2, 'legacy', None)):
        factory = getattr(namespace, factory_name, None)
        if factory is not None:
            return factory()
    raise AttributeError(
        "this OpenCV build provides no %s (checked cv2 and cv2.legacy)" % factory_name)

def osd_draw(frame,obj_bbox,color=(0,0,255)):
    '''
    Draw a labelled box for every entry of obj_bbox onto frame; return frame.

    obj_bbox maps a label to an (x, y, w, h) box. Each box is outlined in
    `color` (thickness 3). A filled blue strip below the box carries the
    label in white text.
    '''
    for label, box in obj_bbox.items():
        left, top, width, height = (int(v) for v in box)
        right = left + width
        bottom = top + height
        frame = cv2.rectangle(frame, (left, top), (right, bottom), color, 3)
        # Label background: anchor at y1 to draw above the box, at y2 to draw below it (here: below).
        frame = cv2.rectangle(frame, (left, bottom), (right + 180, bottom + 50), (255, 0, 0), -1)
        frame = cv2.putText(frame, label, (left, bottom + 40), cv2.FONT_HERSHEY_PLAIN,
                            3, (255, 255, 255), 2, cv2.LINE_AA)
    return frame

def IOU(ax1,ay1,aw,ah,bx1,by1,bw,bh):
    '''
    Return the intersection-over-union of boxes A and B, each given as
    top-left corner plus width/height. Returns 0 when the boxes do not
    overlap (touching edges count as no overlap).
    '''
    overlap_w = min(bx1 + bw, ax1 + aw) - max(ax1, bx1)
    overlap_h = min(by1 + bh, ay1 + ah) - max(ay1, by1)
    if overlap_w > 0 and overlap_h > 0:
        intersection = overlap_w * overlap_h
        union = aw * ah + bw * bh - intersection
        return intersection / union
    return 0

def post_api(dtime,plate,location,imgbytes,url="http://192.168.5.102:8000/test",timeout=10):
    '''
    Upload one cropped object image to the logging API as a multipart POST.

    dtime     -- capture time, sent as the "time" form field
    plate     -- object name; sent as both "plate" and "filename"
    location  -- video source identifier, sent as "location"
    imgbytes  -- encoded image bytes, sent as the "file" part
    url       -- endpoint that receives the POST (previously hard-coded)
    timeout   -- seconds passed to requests as its timeout; without one a
                 stalled server blocks the calling thread forever.
                 Pass None to wait indefinitely.

    Returns the HTTP status code and prints the body of any non-200
    response. Network errors propagate as requests exceptions.
    '''
    print(datetime.datetime.now().strftime("%Y/%m/%d %H:%M:%S"),", %s ,upload to API"%plate)
    payload = {
        'time': dtime,
        'location': location,
        'plate': plate,
        'filename': plate,
    }
    files = [('file', imgbytes)]
    headers = {}

    response = requests.request("POST", url, headers=headers, data=payload,
                                files=files, timeout=timeout)
    if response.status_code != 200:
        print(response.text.encode('utf8'))
    return response.status_code

class Muti_tracker(object):
    '''
    Muti_tracker manages all sub-trackers.

    Sub-trackers live in ``self.trackers``, keyed by consecutive slot
    indexes (0, 1, 2, ...). A slot is either ``[]`` (free: tracking was lost
    or the tracker was dropped) or a 4-item list::

        [tracker, obj_names, bbox, tolerance_frames]

    tracker           cv2 tracker object from createTrackerByName
    obj_names         dict {name: votes}; get_obj_name reports the name
                      with the most votes
    bbox              last known box as (x, y, w, h)
    tolerance_frames  consecutive correct() calls in which no detection
                      matched this tracker; the slot is freed once it
                      exceeds max_toler

    A background thread started by __init__ uploads a crop of every tracked
    object through post_api every ``log_interval`` seconds until stop() is
    called. add() also uploads each new object once, in its own thread.
    '''
    def __init__(self,trackers_type="CSRT",max_toler=5,source="",log_interval=5):
        self.trackers=dict()             # slot index -> [] or [tracker, obj_names, bbox, tolerance_frames]
        self.objects_count=0             # objects added so far; used for default names
        self.trackers_type=trackers_type
        self.max_toler=max_toler
        self.source=source               # source of video stream; uploaded as "location"
        self.log_interval=log_interval   # seconds between periodic uploads
        self.current_frame=None          # latest frame given to update()/correct()
        self.log=True                    # cleared by stop() to end the periodic upload loop
        self.__thread_logloop = Thread(target=self.__logloop)
        self.__thread_logloop.start()

    @staticmethod
    def _init_succeeded(result):
        '''
        Interpret the return value of ``tracker.init``.

        Older/legacy OpenCV trackers return a bool. The OpenCV >= 4.5.1
        tracker API declares init as void, so Python receives None; treat
        that as success instead of silently rejecting every tracker.
        '''
        return result is None or bool(result)

    def add(self,frame,prediction,name=""):
        '''
        Create a sub-tracker for box `prediction` (x, y, w, h) on `frame`.

        The tracker goes into the first free slot, or a new slot after the
        last one. `name` defaults to "object<N>". Nothing is stored if the
        tracker fails to initialise. On success the new object is uploaded
        once in a background thread, cropped from `frame`.
        '''
        if name=="":
            name=f"object{self.objects_count}"
        tracker=createTrackerByName(self.trackers_type)
        if not self._init_succeeded(tracker.init(frame,prediction)):
            return
        self.objects_count+=1
        # New tracker fills the first empty slot; otherwise it goes at the end of the dict.
        index=next((i for i, entry in self.trackers.items() if entry==[]), len(self.trackers))
        tolerance_frames=0
        self.trackers[index]=[tracker,{name:1},prediction,tolerance_frames]
        self.__thread_logjob = Thread(target=self.__logjob, args=(index,frame))
        self.__thread_logjob.start()

    def update(self,frame):
        '''
        Advance every active sub-tracker to `frame`.

        Returns (found, bbox_by_name). found is True if at least one tracker
        succeeded. bbox_by_name maps each successful tracker's current name
        to its new (x, y, w, h). A tracker that fails is freed (slot set to []).
        '''
        self.current_frame=frame
        found=False
        bbox_name=dict()
        for index, entry in self.trackers.items():
            if entry==[]:
                continue
            suc, bbox = entry[0].update(frame)
            if suc:
                found=True
                entry[2]=bbox
                bbox_name[self.get_obj_name(entry[1],index)]=bbox
            else:
                print("track loss")
                self.trackers[index]=[]   # tracking lost: free the slot (size unchanged, safe mid-loop)
        return found,bbox_name

    def correct(self,frame,prediction,objs_name,threshold=0.5):
        '''
        Re-anchor sub-trackers to fresh detections, then run update(frame).

        prediction -- sequence of detected boxes (x, y, w, h)
        objs_name  -- sequence of names, parallel to prediction
        threshold  -- a tracker matches a detection when their IOU is at
                      least 1 - threshold

        Each active tracker takes the first matching detection it can
        re-initialise on. That detection's name gets one more vote (capped
        at 1000) and the tracker's miss counter is reset. A tracker with no
        match has its miss counter incremented and is freed once the counter
        exceeds max_toler. Detections left unmatched become new trackers via
        add(). The caller's prediction/objs_name sequences are not modified.

        Returns the result of update(frame).
        '''
        self.current_frame=frame
        # Work on copies: matched detections are removed as trackers claim them.
        prediction=list(prediction)
        objs_name=list(objs_name)
        if prediction:
            for tra_index in range(len(self.trackers)):
                entry=self.trackers[tra_index]
                if entry==[]:
                    continue
                if not self._match_detection(frame,entry,prediction,objs_name,threshold):
                    entry[3]+=1
                    if entry[3]>self.max_toler:   # missed too many detections: free the slot
                        self.trackers[tra_index]=[]
            for n,bbox in enumerate(prediction):  # add new trackers for the remaining detections
                self.add(frame,bbox,objs_name[n])

        return self.update(frame)

    def _match_detection(self,frame,entry,prediction,objs_name,threshold):
        '''
        Re-anchor tracker `entry` to the first overlapping detection it can
        initialise on. Returns True on a match, and removes the matched
        detection from `prediction`/`objs_name`.
        '''
        bx1,by1,bw,bh=entry[2]
        for pre_index, bbox in enumerate(prediction):
            ax1,ay1,aw,ah=bbox
            if IOU(ax1,ay1,aw,ah,bx1,by1,bw,bh) >= 1-threshold:
                tracker=createTrackerByName(self.trackers_type)
                if self._init_succeeded(tracker.init(frame,bbox)):
                    entry[3]=0          # reset the non-matching tolerance frames
                    entry[0]=tracker    # restart tracking from the detected box
                    name=objs_name[pre_index]
                    votes=entry[1]
                    if votes.get(name):
                        if votes[name]<=999:   # cap at 1000 (see get_obj_name)
                            votes[name]+=1
                    else:
                        votes[name]=1
                    del objs_name[pre_index]
                    del prediction[pre_index]
                    return True
        return False

    def get_obj_name(self,names,index):
        '''
        Return the name with the most votes in `names` ({name: votes}).

        Ties go to the name iterated first, and an empty dict yields "".
        When the winner has reached the 1000-vote cap, the vote dict of
        tracker `index` is reset to {winner: 1}.
        '''
        max_index=""
        max_v=0
        for name, votes in names.items():
            if votes>max_v:
                max_v=votes
                max_index=name
        # Checked after the scan so the reset also happens when the capped name is the last key.
        if max_v>=1000:
            self.trackers[index][1]={max_index:1}
        return max_index

    @staticmethod
    def _copy_frame(frame):
        '''Return a private copy of `frame`, or None when there is no image (e.g. None).'''
        if np.shape(frame) == ():
            return None
        return frame.copy()

    def _post_tracker(self,index,frame):
        '''
        Crop tracker `index`'s box out of `frame` and upload it via post_api.

        Best-effort: free slots and boxes fully outside the frame are
        skipped, and any failure is printed rather than raised, so the
        upload threads keep running.
        '''
        try:
            entry=self.trackers.get(index)
            if not entry:
                return  # slot freed (tracking lost)
            dtime=datetime.datetime.now().timestamp()
            plate=self.get_obj_name(entry[1],index)
            x1,y1,w,h=[int(v) for v in entry[2]]
            # Clip only the start at 0: a negative slice start would wrap to the far edge.
            img_crop=frame[max(y1,0):y1+h,max(x1,0):x1+w]
            if img_crop.size==0:
                return  # box lies entirely outside the frame
            ok,buf=cv2.imencode('.jpg',img_crop)   # encode to JPEG bytes
            if not ok:
                raise ValueError("JPEG encoding failed")
            post_api(dtime,plate,self.source,buf.tobytes())
        except Exception as e:
            print("log upload failed for tracker %s: %s" % (index, e))

    def __logloop(self):
        '''Every log_interval seconds, upload a crop of each tracked object until stop().'''
        while self.log:
            time.sleep(self.log_interval)
            if not self.log:
                break  # stop() was called while sleeping
            frame=self._copy_frame(self.current_frame)
            if frame is None:
                continue  # no frame processed yet
            # Snapshot the keys: add() on the main thread may insert slots meanwhile.
            for index in list(self.trackers):
                self._post_tracker(index,frame)

    def __logjob(self,index,frame=None):
        '''Upload tracker `index` once, cropped from `frame` (default: current_frame).'''
        snapshot=self._copy_frame(self.current_frame if frame is None else frame)
        if snapshot is not None:
            self._post_tracker(index,snapshot)

    def stop(self):
        '''Ask the periodic upload thread to exit; it notices after its current sleep.'''
        self.log=False
            
            
        


        
        
        
        

