
# coding: utf-8

# In[1]:


import cv2
from sklearn.externals import joblib
from skimage.transform import resize
from sklearn.preprocessing import LabelEncoder
from keras.models import model_from_json
import numpy as np


# In[2]:


# Load the Haar-cascade frontal-face detector.
face_cascade = cv2.CascadeClassifier("../model/cv2/haarcascade_frontalface_alt2.xml")
# cv2.CascadeClassifier does not raise when the XML file is missing or
# unreadable; it yields an empty classifier that only fails later inside
# detectMultiScale. Fail fast so the success message is never printed falsely.
if face_cascade.empty():
    raise IOError(
        "Failed to load cascade classifier: "
        "../model/cv2/haarcascade_frontalface_alt2.xml"
    )
print("人臉偵測cascade分類器載入完成")


# In[3]:


# Load the FaceNet prediction model: architecture from JSON, then the weights.
# The context manager guarantees the JSON file handle is closed after reading.
with open("../model/keras/facenet_model.json", "r") as _model_json:
    model = model_from_json(_model_json.read())
model.load_weights("../model/keras/facenet_weights.h5")
print("Facenet預測模型載入完成")
# Uncomment to inspect the network architecture:
# model.summary()


# In[4]:


# Rebuild the LabelEncoder from the saved classes array.
le = LabelEncoder()
le.classes_ = np.load('../model/20190622181256/classes.npy')
# Load the SVM classifier.
# NOTE(review): joblib.load unpickles the file; only load model files from a trusted source.
clf = joblib.load('../model/20190622181256/20190622181256.pkl')
print("SVM分類器載入完成")

# In[5]:


# Image preprocessing
def prewhiten(x):
    """Standardize an image array to zero mean and unit standard deviation.

    A 4-D input is normalized independently along its first axis (statistics
    are taken over axes 1, 2, 3); a 3-D input is normalized as a whole.
    The standard deviation is floored at ``1 / sqrt(size)`` so that a
    constant input does not cause a division by zero.

    Args:
        x: NumPy array with 3 or 4 dimensions.

    Returns:
        Array of the same shape as ``x`` holding the standardized values.

    Raises:
        ValueError: if ``x`` is neither 3-D nor 4-D (the offending ``ndim``
            is printed first).
    """
    if x.ndim not in (3, 4):
        print(x.ndim)
        raise ValueError('Dimension should be 3 or 4')

    if x.ndim == 4:
        reduce_axes, n_elems = (1, 2, 3), x[0].size
    else:
        reduce_axes, n_elems = (0, 1, 2), x.size

    centered = x - np.mean(x, axis=reduce_axes, keepdims=True)
    spread = np.std(x, axis=reduce_axes, keepdims=True)
    # Lower bound on the divisor, scaled by the number of elements per sample.
    floor = 1.0 / np.sqrt(n_elems)
    return centered / np.maximum(spread, floor)

def l2_normalize(x, axis=-1, epsilon=1e-10):
    """Scale ``x`` to unit L2 norm along ``axis``.

    The squared norm is floored at ``epsilon`` before the square root, so an
    all-zero vector is returned as zeros instead of producing NaNs.

    Args:
        x: Array-like of values to normalize.
        axis: Axis along which the norm is computed (default: last axis).
        epsilon: Lower bound applied to the squared norm.

    Returns:
        Array with the same shape as ``x``.
    """
    squared_norm = np.sum(np.square(x), axis=axis, keepdims=True)
    divisor = np.sqrt(np.maximum(squared_norm, epsilon))
    return x / divisor


# In[6]:


# Crop every detected face and return the resized crops as one array.
image_size = 160


def face_cropped(img, faces, margin):
    """Cut each face box out of ``img`` and resize it to a square.

    Args:
        img: Image array indexed as ``img[row, col, channel]``.
        faces: Iterable of ``(x, y, w, h)`` boxes (left, top, width, height).
        margin: Total extra pixels added around each box; half is applied
            on each side.

    Returns:
        ``np.array`` of the crops, each resized to
        ``(image_size, image_size)``.
    """
    half = margin // 2
    aligned_images = []
    for (x, y, w, h) in faces:
        # Clamp the top/left edge at 0: a negative slice start would be
        # interpreted as "counted from the end" and yield an empty or wrong
        # region for faces near the image border. The bottom/right edge needs
        # no clamp because slicing truncates at the array bounds.
        top = max(y - half, 0)
        left = max(x - half, 0)
        cropped = img[top:y + h + half, left:x + w + half, :]
        aligned = resize(cropped, (image_size, image_size), mode='reflect')
        aligned_images.append(aligned)

    return np.array(aligned_images)


# In[7]:


# Compute the FaceNet embeddings for a set of face crops.
def calc_embs(faces, margin=10, batch_size=1):
    """Return L2-normalized model embeddings for ``faces``.

    Args:
        faces: Array of face crops; passed through ``prewhiten`` first.
        margin: Not used by this function; kept in the signature for callers.
        batch_size: Number of crops sent to the model per prediction call.

    Returns:
        Array of embeddings, concatenated over all batches and passed
        through ``l2_normalize``.
    """
    whitened = prewhiten(faces)
    batch_starts = range(0, len(whitened), batch_size)
    raw_embeddings = [
        model.predict_on_batch(whitened[begin:begin + batch_size])
        for begin in batch_starts
    ]
    return l2_normalize(np.concatenate(raw_embeddings))


# In[11]:


# Face recognition inference
def infer(le, clf, img):
    """Detect faces in ``img`` and label each one.

    Args:
        le: Label encoder forwarded to ``get_labels``.
        clf: Classifier forwarded to ``get_labels``.
        img: Image handed to the cascade detector and to ``face_cropped``.

    Returns:
        A message string when no face is detected; otherwise a two-element
        list ``[faces, labels]`` with the detected boxes and the label
        produced for each box. Callers must handle both return types.
    """
    detections = face_cascade.detectMultiScale(img, scaleFactor=1.1, minNeighbors=3)
    if len(detections) == 0:
        return '偵測不到人臉請重新調整'

    crops = face_cropped(img, detections, 10)
    # get_labels is used instead of a plain
    # le.inverse_transform(clf.predict(embs)) so low-probability matches
    # can be reported separately.
    labels = get_labels(le, clf, calc_embs(crops))
    return [detections, labels]
# Resolve classifier probabilities into labels.
def get_labels(le, clf, embs, threshold=0.5):
    """Map each embedding to a class label, or ``'Unknow'`` if uncertain.

    Args:
        le: Object whose ``inverse_transform`` maps class indices to labels.
        clf: Object whose ``predict_proba`` returns one probability row per
            embedding.
        embs: Embeddings passed to ``clf.predict_proba``.
        threshold: The best class probability must be strictly greater than
            this value to be accepted. Defaults to 0.5.

    Returns:
        List with one entry per embedding: the decoded label, or the string
        ``'Unknow'`` when the best probability does not exceed ``threshold``.
    """
    scores = clf.predict_proba(embs)
    results = []
    for row in scores:
        print(row)
        # Find the best class once and reuse it for the check and the lookup.
        best = row.argmax()
        if row[best] > threshold:
            results.append(le.inverse_transform([best])[0])
        else:
            results.append('Unknow')
    return results