import numpy as np
import os
from scipy.optimize import curve_fit
from sklearn.svm import SVR
from sklearn.model_selection import train_test_split # Import train_test_split function
from sklearn.metrics import r2_score
from sklearn.metrics import mean_squared_error
import pickle

class Correcter(object):
    """Wrap a pickled regression model used to correct temperature readings.

    The model is loaded from disk at construction time.  ``fitting`` can train
    a fresh SVR from a text log, but it only returns the estimator; it does
    not replace ``self.clf``.
    """

    def __init__(self, model_path="./20210421_correcter.sav"):
        """Load the pickled model stored at *model_path* into ``self.clf``.

        Args:
            model_path: Path of the pickle file holding the fitted model.

        Raises:
            OSError: If the file cannot be opened.
        """
        # SECURITY: unpickling executes arbitrary code embedded in the file.
        # Only point model_path at files that come from a trusted source.
        with open(model_path, 'rb') as model_file:
            self.clf = pickle.load(model_file)

    # MAPE calculation function
    def mean_absolute_percentage_error(self, y_true, y_pred):
        """Return the mean absolute percentage error, in percent.

        Both arguments are converted to float arrays, so plain lists or
        tuples are accepted as well as NumPy arrays.

        Args:
            y_true: Reference values.  Zero entries make the result
                ``inf``/``nan`` because they are used as the divisor.
            y_pred: Predicted values, same length as *y_true*.

        Returns:
            ``mean(|(y_true - y_pred) / y_true|) * 100``.
        """
        y_true = np.asarray(y_true, dtype=float)
        y_pred = np.asarray(y_pred, dtype=float)
        return np.mean(np.abs((y_true - y_pred) / y_true)) * 100

    def fitting(self, data_path="./logs/argumented_data.txt"):
        """Train an SVR on the log at *data_path* and print hold-out metrics.

        Every non-blank line must contain at least four whitespace-separated
        numbers, read in this order: detect raw value, blackbody raw value,
        blackbody temperature, detect temperature.  The two raw values are
        the features; the target is ``blackbody_temp - detect_temp``.

        Args:
            data_path: Path of the training log.

        Returns:
            The fitted ``SVR`` estimator.  ``self.clf`` is left untouched.

        Raises:
            OSError: If the log cannot be opened.
            ValueError: If a field is not a valid number.
            IndexError: If a non-blank line has fewer than four fields.
        """
        features = []
        targets = []
        with open(data_path) as log:
            for line in log:
                fields = line.split()
                if not fields:
                    # Tolerate blank lines (e.g. a trailing newline at EOF).
                    continue
                detect_rowdata = float(fields[0])
                blackbody_rowdata = float(fields[1])
                blackbody_temp = float(fields[2])
                detect_temp = float(fields[3])
                features.append([detect_rowdata, blackbody_rowdata])
                targets.append(blackbody_temp - detect_temp)

        X_train, X_test, y_train, y_test = train_test_split(
            features, targets, test_size=0.3, random_state=1)

        # A degree-1 polynomial kernel; the name is kept in a local so the
        # metrics line below always reports the kernel actually used.
        kernel = 'poly'
        clf = SVR(kernel=kernel, degree=1, gamma=0.001)
        clf.fit(X_train, y_train)

        predicted = clf.predict(X_test)
        r2 = r2_score(y_test, predicted)
        mse = mean_squared_error(y_test, predicted)
        mape = self.mean_absolute_percentage_error(y_test, predicted)
        print("kernel: {}, r_squre: {}, MSE: {}, MAPE: {}".format(kernel, r2, mse, mape))
        return clf

    def predict(self, x_detect_k, x_black_k, x_black_c):
        """Return the loaded model's prediction for one sample.

        The three arguments are passed to ``self.clf.predict`` as a single
        row ``[x_detect_k, x_black_k, x_black_c]`` and the first element of
        the result is returned.

        Args:
            x_detect_k: Detect-side reading (presumably the raw value of the
                measured target; confirm units with the data source).
            x_black_k: Blackbody reading on the same scale as *x_detect_k*.
            x_black_c: Blackbody temperature (presumably degrees Celsius).

        Returns:
            The scalar predicted by the model for this sample.
        """
        # NOTE(review): the loaded model must accept three features.  A model
        # returned by fitting() is trained on two features and predicts
        # ``blackbody_temp - detect_temp``; it would have to be used as
        # ``x_black_c - clf.predict([[x_detect_k, x_black_k]])[0]`` instead.
        return self.clf.predict([[x_detect_k, x_black_k, x_black_c]])[0]