import math
import MNN
import cv2
import numpy as np

class AntiSpoofing:
    """Face anti-spoofing classifier backed by an MNN MiniFASNet model.

    The model takes a single 1x3x80x80 float input (see ``predict``). Its
    output tensor named ``"Reshape176"`` is turned into probabilities with a
    softmax.
    """

    def __init__(self, model_path="../model/4_0_0_80x80_MiniFASNetV1SE.mnn",
                 num_threads=4):
        """Load the MNN model and create an inference session.

        Args:
            model_path: Path to the ``.mnn`` model file. The default is a
                relative path, so it is presumably resolved against the
                process's working directory rather than this file.
            num_threads: Value passed to MNN as the session's ``numThread``
                config (default 4, matching the previous hard-coded value).
        """
        self.interpreter = MNN.Interpreter(model_path)
        self.session = self.interpreter.createSession({'numThread': num_threads})
        self.input_tensor = self.interpreter.getSessionInput(self.session)

    def predict(self, image):
        """Return the softmax probability at output index 1 for one face crop.

        Args:
            image: An HxWx3 image array. The crop is resized to 80x80 and
                converted HWC -> CHW float32 with no normalization. Channel
                order is presumably OpenCV's BGR (TODO confirm against the
                model's training preprocessing).

        Returns:
            The probability at index 1 of the model output, rounded to three
            decimals. Presumably this is the "live" class, judging by the
            variable name; verify against the model's label map.
        """
        resized = cv2.resize(image, (80, 80))
        # HWC -> CHW. ascontiguousarray guarantees a C-ordered float32 buffer
        # for MNN.Tensor (the transpose alone only produces a strided view).
        chw = np.ascontiguousarray(resized.transpose((2, 0, 1)), dtype=np.float32)
        tmp_input = MNN.Tensor((1, 3, 80, 80), MNN.Halide_Type_Float, chw,
                               MNN.Tensor_DimensionType_Caffe)
        self.input_tensor.copyFrom(tmp_input)
        self.interpreter.runSession(self.session)
        logits = self.interpreter.getSessionOutput(self.session, "Reshape176").getData()
        probs = self.softmax_py(logits)
        return probs[1]

    def softmax_py(self, logits_data):
        """Compute a numerically stable softmax over a flat sequence of logits.

        The maximum logit is subtracted before exponentiating. This leaves
        the result unchanged mathematically but avoids ``OverflowError``
        from ``math.exp`` on large logits.

        Args:
            logits_data: Iterable of numbers convertible to ``float``.

        Returns:
            List of probabilities, each rounded to 3 decimal places. An empty
            input returns an empty list.
        """
        logits = [float(v) for v in logits_data]
        if not logits:
            return []
        peak = max(logits)
        exps = [math.exp(v - peak) for v in logits]
        total = sum(exps)
        return [round(e / total, 3) for e in exps]

    def scale_box(self, img_h, img_w, box, scale=2):
        """Expand a box about its centre and clip it to the image bounds.

        Each side is pushed outward by ``(scale - 1)`` times the box's own
        width (horizontally) or height (vertically). The result is therefore
        ``2*scale - 1`` times the original size. With the default
        ``scale=2`` this is a 3x box, identical to the previous
        ``2*x1 - x2`` formula; ``scale=1`` returns the box unchanged.

        Args:
            img_h: Image height; the bottom edge is clipped to this value.
            img_w: Image width; the right edge is clipped to this value.
            box: ``(x1, y1, x2, y2)`` corner coordinates.
            scale: Expansion factor as described above.

        Returns:
            ``[x1, y1, x2, y2]``, with the top-left corner clipped to 0 and
            the bottom-right corner clipped to ``(img_w, img_h)``.
        """
        x1, y1, x2, y2 = box
        pad = scale - 1
        box_w = x2 - x1
        box_h = y2 - y1
        new_x1 = max(x1 - pad * box_w, 0)
        new_y1 = max(y1 - pad * box_h, 0)
        new_x2 = min(x2 + pad * box_w, img_w)
        new_y2 = min(y2 + pad * box_h, img_h)
        return [new_x1, new_y1, new_x2, new_y2]