# import common

# from retina_utils.TrtRetinaPlate import TrtRetinaPlate
import sys
import os
# Make the sibling packages imported below (retina_utils, ocr_utils) importable
# no matter where the process was started from. abspath() matters because
# dirname(__file__) may be relative (or even ''), and relative sys.path entries
# are resolved against the *current* working directory at import time, so they
# silently break after a chdir. The membership check avoids duplicate entries
# when this module is reloaded.
root_path = os.path.dirname(os.path.abspath(__file__))
if root_path not in sys.path:
    sys.path.append(root_path)
# from retina_utils.TrtRetinaGrayPlate import TrtRetinaPlate as TrtRetinaGrayPlate
# from retina_utils.TrtRetinaGrayPlateSHM import TrtRetinaPlateSHM as TrtRetinaGrayPlateSHM

from retina_utils.alignment import align
from retina_utils.TrtRetinaColorPlate import TrtRetinaPlate as TrtRetinaColorPlate
from retina_utils.TrtRetinaColorPlateSHM import TrtRetinaPlateSHM as TrtRetinaColorPlateSHM 


from ocr_utils.TrtCtcOcr import TrtCtcOcr
from ocr_utils.TrtCtcOcrSHM import TrtCtcOcrSHM

import cv2
import numpy as np
import time
import threading, queue

class TrtRetinaPlateThread(threading.Thread):
    """Worker thread that detects license plates in queued vehicle images.

    The thread consumes ``VehicleInfo`` items from ``vehicles_q``, runs the
    TensorRT RetinaPlate detector on each image and, when something is
    detected, crops the plate of the first detection row (landmark alignment
    when available, plain bounding-box crop otherwise). It then publishes a
    ``PlateInfo`` onto ``plates_q`` for the OCR stage. Items that are not
    ``VehicleInfo`` or carry no image are dropped.
    """

    def __init__(self, condition, vehicles_q, plates_q, conf_th=0.9, budget=10,
                 poll_timeout=1.0):
        """__init__
        # Arguments
            condition: threading.Condition kept for API compatibility; it is
                       stored but not used by this implementation.
            vehicles_q: queue (get/put) of VehicleInfo items to process.
            plates_q: queue that receives one PlateInfo per detected plate.
            conf_th: confidence threshold passed to the detector.
            budget: unused; kept for API compatibility.
            poll_timeout: seconds to block waiting on vehicles_q before
                          re-checking whether stop() was requested.
        """
        threading.Thread.__init__(self)
        self.condition = condition
        self.vehicles_q = vehicles_q
        self.plates_q = plates_q
        self.conf_th = conf_th
        self.poll_timeout = poll_timeout
        self.trt_plate = None   # created in run(); see the NOTE there
        self.running = False
        # Set by stop(). Without it, a stop() issued before run() flips
        # `running` to True would be lost and join() would hang forever.
        self._stop_requested = threading.Event()
        # Landmark-based plate aligner; set to None to always use box crops.
        self.alignment = align

    @staticmethod
    def _crop_box(img, box):
        """Return ``img[y1:y2, x1:x2]`` for a box ``(x1, y1, x2, y2)``.

        Coordinates are clamped to the image so that a box poking outside
        the frame (e.g. a slightly negative corner) yields its visible part
        instead of an empty or wrapped-around slice caused by negative
        indexing.
        """
        height, width = img.shape[:2]
        x1, y1, x2, y2 = (int(v) for v in box[:4])
        x1 = min(max(x1, 0), width)
        x2 = min(max(x2, 0), width)
        y1 = min(max(y1, 0), height)
        y2 = min(max(y2, 0), height)
        return img[y1:y2, x1:x2]

    def run(self):
        """Process vehicles until stop() is called or `running` is cleared.

        NOTE: CUDA context is created here, i.e. inside the thread
        which calls CUDA kernels.  In other words, creating CUDA
        context in __init__() doesn't work.
        """
        print('TrtThread: loading the TRT RetinaPlate engine...')
        # The gray / non-SHM engine variants remain commented out in the
        # file's imports; the shared-memory color engine is the one in use.
        self.trt_plate = TrtRetinaColorPlateSHM((320,320))

        print('TrtThread: TRT RetinaPlate start running...')
        self.running = True
        while self.running and not self._stop_requested.is_set():
            try:
                # Block briefly instead of sleep-polling: new vehicles are
                # handled immediately, and stop() is still noticed promptly.
                info = self.vehicles_q.get(timeout=self.poll_timeout)
            except queue.Empty:
                continue
            if not isinstance(info, VehicleInfo) or info.img is None:
                continue
            try:
                output = self.trt_plate.detect(info.img, self.conf_th)
            except Exception as e:
                # Best-effort, like the OCR thread: one bad frame must not
                # kill the worker and leave the queue unconsumed forever.
                print(e)
                continue
            if output.shape[0] == 0:
                continue
            # Columns: 0:4 box, 4 confidence, 5: landmarks.
            boxes = output[:, :4].astype(int)
            # TODO: only the first detection is used; revisit if multiple
            # plates / batched outputs need to be supported.
            loc = boxes[0]
            land = output[:, 5:][0]
            crop_img = None
            if self.alignment is not None:
                crop_img = self.alignment(info.img, land, padding=5)
            if crop_img is None:
                # Alignment disabled or failed: fall back to the box crop.
                crop_img = self._crop_box(info.img, loc)
            self.plates_q.put(PlateInfo(info.track_id, info.img, crop_img))
        del self.trt_plate
        print('TrtThread: stopped...')

    def stop(self):
        """Ask the worker to exit and wait for it to finish.

        Safe to call before the thread was started (nothing to join then).
        """
        self._stop_requested.set()
        self.running = False
        if self.ident is not None:
            self.join()
        
class TrtCTCOCRThread(threading.Thread):
    """Worker thread that reads plate text from queued plate crops.

    The thread consumes ``PlateInfo`` items from ``plates_q``, runs the
    TensorRT CTC OCR engine on each ``plate_img`` and publishes a
    ``ResultInfo`` onto ``results_q``. If recognition raises, the error is
    printed and an empty plate string is published, so the track still gets
    a result. Items that are not ``PlateInfo`` or carry no plate image are
    dropped.
    """

    def __init__(self, condition, plates_q, results_q, budget=10, poll_timeout=1.0):
        """__init__
        # Arguments
            condition: threading.Condition kept for API compatibility; it is
                       stored but not used by this implementation.
            plates_q: queue (get/put) of PlateInfo items to recognise.
            results_q: queue that receives one ResultInfo per plate.
            budget: unused; kept for API compatibility.
            poll_timeout: seconds to block waiting on plates_q before
                          re-checking whether stop() was requested.
        """
        threading.Thread.__init__(self)
        self.condition = condition
        self.plates_q = plates_q
        self.results_q = results_q
        self.poll_timeout = poll_timeout
        self.trt_ocr = None   # created in run(); see the NOTE there
        self.running = False
        # Set by stop(). Without it, a stop() issued before run() flips
        # `running` to True would be lost and join() would hang forever.
        self._stop_requested = threading.Event()

    def run(self):
        """Recognise plates until stop() is called or `running` is cleared.

        NOTE: CUDA context is created here, i.e. inside the thread
        which calls CUDA kernels.  In other words, creating CUDA
        context in __init__() doesn't work.
        """
        print('TrtThread: loading the TRT CTC OCR engine...')
        # The non-SHM TrtCtcOcr engine is also imported; the shared-memory
        # variant is the one in use.
        self.trt_ocr = TrtCtcOcrSHM()

        print('TrtThread: TRT CTC OCR start running...')
        self.running = True
        while self.running and not self._stop_requested.is_set():
            try:
                # Block briefly instead of sleep-polling: new plates are
                # handled immediately, and stop() is still noticed promptly.
                info = self.plates_q.get(timeout=self.poll_timeout)
            except queue.Empty:
                continue
            if not isinstance(info, PlateInfo) or info.plate_img is None:
                continue
            try:
                plate = self.trt_ocr.recognition(info.plate_img)
            except Exception as e:
                # Best-effort: report and publish an empty plate string.
                print(e)
                plate = ""
            self.results_q.put(
                ResultInfo(info.track_id, plate, info.vechile_img, info.plate_img))
        del self.trt_ocr
        print('TrtThread: stopped...')

    def stop(self):
        """Ask the worker to exit and wait for it to finish.

        Safe to call before the thread was started (nothing to join then).
        """
        self._stop_requested.set()
        self.running = False
        if self.ident is not None:
            self.join()
        
        
class VehicleInfo(object):
    """Work item queued by EZLPR.put() for the plate-detection thread."""

    def __init__(self,track_id, img):
        # track_id is echoed through to the final result; img is the vehicle
        # image the detector will search for a plate.
        self.track_id, self.img = track_id, img
        
class PlateInfo(object):
    """Work item passed from the plate-detection thread to the OCR thread.

    Attributes:
        track_id: id copied from the originating VehicleInfo.
        vechile_img: the vehicle image the plate was detected in. The
            misspelled name is kept for backward compatibility;
            ``vehicle_img`` is a correctly spelled alias.
        plate_img: the plate crop (aligned or box-cropped) to recognise.
    """

    def __init__(self,track_id, vechile_img, plate_img):
        self.track_id = track_id
        self.vechile_img = vechile_img
        self.plate_img = plate_img

    @property
    def vehicle_img(self):
        """Correctly spelled alias for ``vechile_img``."""
        return self.vechile_img

    @vehicle_img.setter
    def vehicle_img(self, value):
        self.vechile_img = value
        
class ResultInfo(object):
    """Final OCR result for one plate, as yielded by EZLPR.out().

    Attributes:
        track_id: id originally passed to EZLPR.put().
        plate: recognised plate text ("" when recognition raised).
        vechile_img: the vehicle image the plate came from. The misspelled
            name is kept for backward compatibility; ``vehicle_img`` is a
            correctly spelled alias.
        plate_img: the plate crop that was recognised.
    """

    def __init__(self,track_id, plate,vechile_img, plate_img):
        self.track_id = track_id
        self.plate = plate
        self.vechile_img = vechile_img
        self.plate_img = plate_img

    @property
    def vehicle_img(self):
        """Correctly spelled alias for ``vechile_img``."""
        return self.vechile_img

    @vehicle_img.setter
    def vehicle_img(self, value):
        self.vechile_img = value
        
        
class EZLPR(object):
    """Two-stage license-plate recognition pipeline running on worker threads.

    put() queues a vehicle image; TrtRetinaPlateThread detects and crops the
    plate; TrtCTCOCRThread reads its text; out() drains finished results.

    Prefer calling close() explicitly when done: __del__ is not guaranteed to
    run promptly, and the worker threads are non-daemon, so interpreter
    shutdown waits for them.
    """

    def __init__(self):
        self.vehiclesQueue = queue.Queue()
        self.platesQueue = queue.Queue()
        self.resultsQueue = queue.Queue()

        # Handed to both workers for API compatibility; currently unused.
        condition = threading.Condition()
        self.trtRetinaPlateThread = TrtRetinaPlateThread(condition, self.vehiclesQueue, self.platesQueue)
        self.trtCtcOcrThread = TrtCTCOCRThread(condition, self.platesQueue, self.resultsQueue)
        self.trtRetinaPlateThread.start()  # start the child threads
        self.trtCtcOcrThread.start()

    def put(self, track_id, crop_img):
        """Queue ``crop_img`` (a vehicle image) for plate detection and OCR.

        ``track_id`` is echoed back unchanged in the matching out() result.
        """
        self.vehiclesQueue.put(VehicleInfo(track_id=track_id, img=crop_img))

    def out(self):
        """Yield ``(track_id, plate, vehicle_img, plate_img)`` tuples.

        Only results already available are yielded; the generator never
        blocks waiting for new ones.
        """
        while True:
            try:
                result = self.resultsQueue.get_nowait()
            except queue.Empty:
                return
            yield result.track_id, result.plate, result.vechile_img, result.plate_img

    def close(self):
        """Stop both worker threads and wait for them to exit. Idempotent."""
        if getattr(self, '_closed', False):
            return
        self._closed = True
        # Same order as before: the detector (producer) first, then OCR.
        # getattr() tolerates an __init__ that failed before creating them.
        for worker in (getattr(self, 'trtRetinaPlateThread', None),
                       getattr(self, 'trtCtcOcrThread', None)):
            if worker is not None:
                worker.stop()

    def __del__(self):
        self.close()
