from __future__ import print_function
from __future__ import division

import argparse
import random
import torch
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
from torch.autograd import Variable
import numpy as np
import datetime
# from warpctc_pytorch import CTCLoss
from torch.nn import CTCLoss
import os
from crnn_pytorch_Gas import ocr_utils as utils
from crnn_pytorch_Gas import dataset

from crnn_pytorch_Gas.models import crnn as net
from crnn_pytorch_Gas import params
from crnn_pytorch_Gas.EarlyStopping import EarlyStopping

# import ocr_utils as utils
# import dataset
# from models import crnn as net
# import params
# from EarlyStopping import EarlyStopping


class crnn_train():
    def __init__(self):
        """Set up the model, loss, scratch tensors, optimizer and LR scheduler.

        Every setting is read from the ``params`` module. The constructor
        creates ``params.expr_dir`` when it does not exist, seeds the Python,
        NumPy and torch RNGs with ``params.manualSeed``, builds the network
        through ``net_init()`` and prints it. ``train_loader`` and
        ``val_loader`` stay ``None`` until ``data_loader()`` is called.
        """
        if not os.path.exists(params.expr_dir):
            os.makedirs(params.expr_dir)
        # Seed every RNG in use so repeated runs draw the same random numbers.
        random.seed(params.manualSeed)
        np.random.seed(params.manualSeed)
        torch.manual_seed(params.manualSeed)
        cudnn.benchmark = True

        self.train_loader = None
        self.val_loader = None
        self.crnn = self.net_init()
        print(self.crnn)

        # -----------------------------------------------
        # Helpers defined in ocr_utils.
        # Running average used for the training loss.
        self.loss_avg = utils.averager()
        # Converts between strings and labels.
        self.converter = utils.strLabelConverter(params.alphabet)

        # -----------------------------------------------
        # Criterion. Other reductions are "sum" and "none" ("none" would be
        # the one needed for OHEM).
        self.criterion = CTCLoss(reduction="mean")

        # -----------------------------------------------
        # Scratch tensors shared by train() and val(); the two never run at
        # the same time, so one set of buffers is enough.
        self.image = torch.FloatTensor(params.batchSize, 3, params.imgH, params.imgW)
        self.text = torch.LongTensor(params.batchSize * 5)
        self.length = torch.LongTensor(params.batchSize)

        if params.cuda and torch.cuda.is_available():
            self.criterion = self.criterion.cuda()
            self.image = self.image.cuda()
            self.text = self.text.cuda()
            self.length = self.length.cuda()

            self.crnn = self.crnn.cuda()
            if params.multi_gpu:
                # net_init() already returns a DataParallel model when a
                # pretrained checkpoint is loaded with multi_gpu set. Unwrap
                # it first so the network is never wrapped twice.
                base_model = self.crnn
                if isinstance(base_model, torch.nn.DataParallel):
                    base_model = base_model.module
                self.crnn = torch.nn.DataParallel(
                    base_model, device_ids=list(range(params.ngpu)))

        self.image = Variable(self.image)
        self.text = Variable(self.text)
        self.length = Variable(self.length)

        # -----------------------------------------------
        # Optimizer: the first enabled flag in params wins, RMSprop is the
        # fallback when none is set.
        if params.adam:
            print("using adam")
            self.optimizer = optim.Adam(self.crnn.parameters(), lr=params.lr, betas=(params.beta1, 0.999))
        elif params.adadelta:
            print("using Adadelta")
            self.optimizer = optim.Adadelta(self.crnn.parameters())
        elif params.AdaGrad:
            print("using Adagrad")
            self.optimizer = optim.Adagrad(self.crnn.parameters(), lr=1e-3, lr_decay=1e-8, weight_decay=0, initial_accumulator_value=0)
        elif params.SGD:
            print("using SGD")
            self.optimizer = optim.SGD(self.crnn.parameters(), lr=params.lr, momentum=0.9, weight_decay=1e-7)
        else:
            print("using RMSprop")
            self.optimizer = optim.RMSprop(self.crnn.parameters(), lr=params.lr)
        # Lowers the learning rate when the monitored value stops decreasing.
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, 'min', cooldown=10, patience=5, verbose=True)

        # -----------------------------------------------
        # Handling of NaN/inf losses depends on the torch version.
        if params.dealwith_lossnan:
            # NOTE(review): this is a plain string comparison of version
            # numbers, kept as in the original.
            if torch.__version__ >= '1.1.0':
                # zero_infinity zeroes infinite losses and their gradients;
                # such losses mainly occur when the inputs are too short to
                # be aligned to the targets. Available from torch 1.1.0.
                self.criterion = CTCLoss(zero_infinity=True)
            else:
                # Older torch: rely on the model's own backward hook to turn
                # inf gradients into zero.
                self.crnn.register_backward_hook(self.crnn.backward_hook)

        # -----------------------------------------------
        # val() only saves a checkpoint when accuracy exceeds this value.
        self.best_acc = 0.9
        
    # -----------------------------------------------
    """
    In this block
        Get train and val data_loader
    """
    def data_loader(self,args):
        """Create the training and validation loaders and store them on self.

        Args:
            args: object with ``trainroot`` and ``valroot`` attributes, the
                paths handed to ``dataset.lmdbDataset``.

        Returns:
            The tuple ``(self.train_loader, self.val_loader)``.
        """
        # train
        train_dataset = dataset.lmdbDataset(root=args.trainroot)
        assert train_dataset
        if not params.random_sample:
            sampler = dataset.randomSequentialSampler(train_dataset, params.batchSize)
        else:
            sampler = None
        # DataLoader raises ValueError when both a sampler and shuffle=True
        # are given, so shuffle only when no custom sampler is in use.
        self.train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=params.batchSize,
            shuffle=(sampler is None),
            sampler=sampler,
            num_workers=int(params.workers),
            collate_fn=dataset.alignCollate(imgH=params.imgH, imgW=params.imgW, keep_ratio=params.keep_ratio))

        # val
        val_dataset = dataset.lmdbDataset(root=args.valroot, transform=dataset.resizeNormalize((params.imgW, params.imgH)))
        assert val_dataset
        self.val_loader = torch.utils.data.DataLoader(val_dataset, shuffle=True, batch_size=params.batchSize, num_workers=int(params.workers))

        return self.train_loader, self.val_loader


    # -----------------------------------------------
    """
    In this block
        Net init
        Weight init
        Load pretrained model
    """
    def weights_init(self,m):
        """Re-initialise one module in place, chosen by its class name.

        Modules whose class name contains ``Conv`` get weights drawn from
        N(0.0, 0.02). Modules whose class name contains ``BatchNorm`` get
        weights drawn from N(1.0, 0.02) and a zero bias. Any other module is
        left untouched. Intended to be passed to ``Module.apply``.
        """
        kind = type(m).__name__
        if 'Conv' in kind:
            m.weight.data.normal_(0.0, 0.02)
            return
        if 'BatchNorm' in kind:
            m.weight.data.normal_(1.0, 0.02)
            m.bias.data.fill_(0)

    def net_init(self):
        """Build the CRNN, initialise its weights and optionally load a checkpoint.

        The number of output classes is ``len(params.alphabet) + 1``. When
        ``params.pretrained`` is a non-empty path, its state dict is loaded;
        with ``params.multi_gpu`` set the model is wrapped in ``DataParallel``
        before loading.

        Returns:
            The constructed model, which is also stored as ``self.crnn``.
        """
        nclass = len(params.alphabet) + 1
        self.crnn = net.CRNN(params.imgH, params.nc, nclass, params.nh)
        self.crnn.apply(self.weights_init)
        if params.pretrained != '':
            print(params.expr_dir)
            print('loading pretrained model from %s' % params.pretrained)
            if params.multi_gpu:
                self.crnn = torch.nn.DataParallel(self.crnn)
            # Load onto CPU so a checkpoint saved on a GPU can also be read on
            # a machine without CUDA; load_state_dict copies the values into
            # the model's own parameters.
            state_dict = torch.load(params.pretrained, map_location='cpu')
            self.crnn.load_state_dict(state_dict)
        return self.crnn
    

    def val(self,net, criterion):
        """Evaluate ``net`` on ``self.val_loader`` and return the average loss.

        Gradients are switched off on the parameters of ``net`` and it is put
        in eval mode; ``train()`` turns both back on. For every batch the
        loss is accumulated and the decoded predictions are compared with the
        ground-truth strings. A few predictions from the last batch are
        printed (at most ``params.n_val_disp``). When the accuracy exceeds
        ``self.best_acc`` the new best is recorded and the state dict of
        ``net`` is saved under ``params.expr_dir``.

        Args:
            net: the model to evaluate.
            criterion: loss called as
                ``criterion(preds, text, preds_size, length)``.

        Returns:
            The average validation loss reported by ``utils.averager``.
        """
        print('Start val')

        for p in net.parameters():
            p.requires_grad = False

        net.eval()

        n_correct = 0
        n_samples = 0
        loss_avg = utils.averager()  # separate from self.loss_avg, which train uses

        # Results of the most recent batch, kept for the sample print-out.
        preds = None
        preds_size = None
        sim_preds = []
        cpu_texts_decode = []

        for cpu_images, cpu_texts in self.val_loader:
            batch_size = cpu_images.size(0)
            n_samples += batch_size
            utils.loadData(self.image, cpu_images)
            t, l = self.converter.encode(cpu_texts)
            utils.loadData(self.text, t)
            utils.loadData(self.length, l)

            preds = net(self.image)
            # Every sample in the batch uses the full output sequence length.
            preds_size = Variable(torch.LongTensor([preds.size(0)] * batch_size))
            cost = criterion(preds, self.text, preds_size, self.length)
            loss_avg.add(cost)

            _, preds = preds.max(2)
            preds = preds.transpose(1, 0).contiguous().view(-1)
            sim_preds = self.converter.decode(preds.data, preds_size.data, raw=False)
            # Ground-truth labels are decoded as UTF-8 bytes here.
            cpu_texts_decode = [text.decode('utf-8', 'strict') for text in cpu_texts]
            for pred, target in zip(sim_preds, cpu_texts_decode):
                if pred == target:
                    n_correct += 1

        if preds is not None:
            raw_preds = self.converter.decode(preds.data, preds_size.data, raw=True)[:params.n_val_disp]
            for raw_pred, pred, gt in zip(raw_preds, sim_preds, cpu_texts_decode):
                print('%-20s => %-20s, gt: %-20s' % (raw_pred, pred, gt))

        # Divide by the number of samples actually seen: the last batch may
        # hold fewer than params.batchSize items.
        accuracy = n_correct / float(n_samples) if n_samples else 0.0
        avg = loss_avg.val()
        print('Val loss: %f, accuray: %f' % (avg, accuracy))

        if accuracy > self.best_acc:
            self.best_acc = accuracy
            val_save_path = '{0}/netCRNN_{1}_{2}.pth'.format(params.expr_dir, str(accuracy), avg)
            print("Save val:", val_save_path)
            torch.save(net.state_dict(), val_save_path)
        return avg


    def train(self,net, criterion, optimizer, train_iter):
        """Run one optimisation step on the next batch of ``train_iter``.

        Args:
            net: the model to train; gradients are enabled on its parameters
                and it is put in training mode.
            criterion: loss called as
                ``criterion(preds, text, preds_size, length)``.
            optimizer: optimizer whose gradients are zeroed before the
                forward pass and which is stepped after ``backward()``.
            train_iter: iterator yielding ``(images, texts)`` batches.

        Returns:
            The loss computed by ``criterion`` for this batch.

        Raises:
            StopIteration: when ``train_iter`` has no batch left.
        """
        for p in net.parameters():
            p.requires_grad = True
        net.train()

        # Builtin next() works for any iterator; the .next() method is not
        # part of the Python 3 iterator protocol.
        cpu_images, cpu_texts = next(train_iter)
        batch_size = cpu_images.size(0)
        utils.loadData(self.image, cpu_images)
        t, l = self.converter.encode(cpu_texts)
        utils.loadData(self.text, t)
        utils.loadData(self.length, l)

        optimizer.zero_grad()
        preds = net(self.image)
        # Every sample in the batch uses the full output sequence length.
        preds_size = Variable(torch.LongTensor([preds.size(0)] * batch_size))
        cost = criterion(preds, self.text, preds_size, self.length)

        # Disabled OHEM variant (needs a criterion with reduction="none"):
        #     sorted_ohem_loss, idx = torch.sort(cost, descending=True)
        #     # sort the losses in descending order
        #     keep_num = min(sorted_ohem_loss.size()[0], batch_size)
        #     # number of losses to keep
        #     if keep_num < sorted_ohem_loss.size()[0]:
        #         # keep only the top keep_num losses when there are more
        #         # than that; otherwise keep all of them
        #         keep_idx_cuda = idx[:keep_num]
        #         cost = sorted_ohem_loss[keep_idx_cuda]
        #     cost = cost.sum() / keep_num

        cost.backward()
        optimizer.step()
        return cost


    def main(self,args=None):
        """Parse the dataset paths and run the epoch-based training loop.

        Args:
            args: optional argument list handed to ``argparse``; ``None``
                lets argparse read the process command line. Both
                ``--trainroot`` and ``--valroot`` are required.

        Each epoch runs ``self.train`` once per batch of the training loader,
        feeds the epoch's average loss to the LR scheduler, prints it and
        resets the running average. Validation runs only when that average is
        at most ``(1e-1)*3`` and either it exceeds the value recorded at the
        previous validation by more than ``(1e-2)*5`` or the epoch index is a
        multiple of 10. The validation loss is passed to ``EarlyStopping``.

        NOTE: the ``early_stop`` flag is never checked here, so the loop
        always runs for ``params.nepoch`` epochs. An older per-iteration
        scheme (display/validation intervals, learning-rate reset, early
        stop break) was disabled in the original and is not reproduced.
        """
        cli = argparse.ArgumentParser()
        cli.add_argument('-train', '--trainroot', required=True, help='path to train dataset')
        cli.add_argument('-val', '--valroot', required=True, help='path to val dataset')
        args = cli.parse_args(args)

        if torch.cuda.is_available() and not params.cuda:
            print("WARNING: You have a CUDA device, so you should probably set cuda in params.py to True")

        # The validation loader is reached through self.val_loader in val().
        train_loader, _ = self.data_loader(args)
        checkpoint_path = '{0}/netCRNN_{1}.pth'.format(params.expr_dir, str(datetime.date.today()))
        early_stopping = EarlyStopping(patience=2000, verbose=True, delta=0, path=checkpoint_path)

        # Training loss recorded the last time validation ran.
        loss_at_last_val = 0
        for epoch in range(params.nepoch):
            batches = iter(train_loader)
            steps_done = 0
            for _ in range(len(train_loader)):
                batch_cost = self.train(self.crnn, self.criterion, self.optimizer, batches)
                self.loss_avg.add(batch_cost)
                steps_done += 1

            epoch_loss = self.loss_avg.val()
            self.scheduler.step(epoch_loss)

            print('[%d/%d][%d/%d] Loss: %f' %
                  (epoch, params.nepoch, steps_done, len(train_loader), self.loss_avg.val()))
            self.loss_avg.reset()

            # Thresholds are kept as the original float expressions.
            if epoch_loss <= (1e-1)*3:
                loss_went_up = epoch_loss - loss_at_last_val > (1e-2)*5
                if loss_went_up or (epoch % 10) == 0:
                    loss_at_last_val = epoch_loss
                    val_loss = self.val(self.crnn, CTCLoss(reduction="mean"))
                    early_stopping(val_loss, self.crnn)

