import sys
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import torch
from torch.utils.data import Dataset
from model.MTCNN_nets import PNet
import time
import copy
import torch.nn as nn

def weights_init(m):
    """Initialise Conv2d/Linear layers; meant to be used via ``module.apply``.

    Weights get Xavier-uniform initialisation and biases are set to the
    constant 0.1.  Any other module type is left untouched.

    Args:
        m: The module visited by ``nn.Module.apply``.
    """
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.xavier_uniform_(m.weight.data)
        # Layers constructed with bias=False expose ``m.bias = None``.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.1)
        

def _resolve_mtcnn_root(root_path):
    """Map ``root_path`` to the MTCNN project root used by this script.

    The path is re-rooted under ``/notebooks`` unless it already starts with
    it, and is truncated right after the first ``MTCNN`` component (if any).
    """
    path = '/' if root_path.startswith('/notebooks') else '/notebooks'
    for part in root_path.split('/'):
        path = os.path.join(path, part)
        if part == 'MTCNN':
            break
    return path


def _ohem_reduce(per_sample_loss, is_train, keep_num, batch_size):
    """Reduce an un-reduced loss tensor to a scalar, with OHEM when training.

    Training: sum of the ``keep_num`` largest per-sample losses divided by
    ``keep_num`` (online hard example mining).  Validation: sum over all
    samples divided by ``batch_size``.
    """
    if per_sample_loss.dim() > 1:
        # Offset/landmark losses are (N, K); score each sample by its row total
        # so the sort below ranks samples rather than coordinates within a row.
        per_sample_loss = per_sample_loss.sum(dim=1)
    if is_train:
        hardest, _ = torch.sort(per_sample_loss, descending=True)
        return hardest[:keep_num].sum() / keep_num
    return per_sample_loss.sum() / batch_size


def _run_phase(phase, model, loader, optimizer, criteria, device, batch_size):
    """Run one ``'train'`` or ``'val'`` pass over ``loader``.

    Returns ``(loss, cls, offset, lands, correct, num_gt)``.  The four loss sums
    are Python floats.  Each batch's value is multiplied by the factor it was
    divided by (``keep_num`` when training, ``batch_size`` when validating), so
    the caller can divide by the dataset size.
    """
    is_train = phase == 'train'
    model.train(is_train)  # train(False) is equivalent to eval()
    loss_cls, loss_offset, loss_lands = criteria
    # The OHEM budget is 70% of a *full* batch, also for a shorter last batch.
    keep_num = int(batch_size * 0.7)
    scale = keep_num if is_train else batch_size

    run_loss = run_cls = run_offset = run_lands = 0.0
    run_correct = 0
    run_gt = 0
    for sample_batched in loader:
        input_images = sample_batched['input_img'].to(device)
        gt_label = sample_batched['label'].to(device)
        gt_offset = sample_batched['bbox_target'].float().to(device)
        gt_landmarks = sample_batched['landmark'].float().to(device)

        if is_train:
            optimizer.zero_grad()

        # Track autograd history only while training.
        with torch.set_grad_enabled(is_train):
            pred_landmarks, pred_offsets, pred_label = model(input_images)
            # flatten(start_dim=1) instead of squeeze(): squeeze would also drop
            # the batch dimension when the last batch holds a single sample.
            # NOTE(review): assumes each PNet head is (B, C, 1, 1) for 12x12 input — confirm
            pred_landmarks = torch.flatten(pred_landmarks, start_dim=1)
            pred_offsets = torch.flatten(pred_offsets, start_dim=1)
            pred_label = torch.flatten(pred_label, start_dim=1)

            # Labels >= 0 take part in the classification loss.
            mask_cls = gt_label >= 0
            # Every non-zero label takes part in the box-regression loss.
            mask_offset = gt_label != 0
            # Label -2 marks samples that carry landmark targets.
            mask_land = gt_label == -2

            loss = torch.zeros((), device=device)
            cls_value = offset_value = lands_value = 0.0
            batch_correct = 0

            valid_gt_label = gt_label[mask_cls]
            batch_gt = len(valid_gt_label)
            if batch_gt:
                valid_pred_label = pred_label[mask_cls]
                cls_loss = _ohem_reduce(loss_cls(valid_pred_label, valid_gt_label),
                                        is_train, keep_num, batch_size)
                loss = loss + 0.3 * cls_loss
                cls_value = cls_loss.item()
                pred = torch.max(valid_pred_label, 1)[1]
                batch_correct = (pred == valid_gt_label).sum().item()

            if mask_offset.any():
                offset_loss = _ohem_reduce(
                    loss_offset(pred_offsets[mask_offset], gt_offset[mask_offset]),
                    is_train, keep_num, batch_size)
                loss = loss + 0.6 * offset_loss
                offset_value = offset_loss.item()

            if mask_land.any():
                lands_loss = _ohem_reduce(
                    loss_lands(pred_landmarks[mask_land], gt_landmarks[mask_land]),
                    is_train, keep_num, batch_size)
                loss = loss + 0.1 * lands_loss
                lands_value = lands_loss.item()

            # A batch without any usable sample leaves ``loss`` with no graph.
            if is_train and loss.requires_grad:
                loss.backward()
                optimizer.step()

        # Accumulate Python floats, not tensors, so no autograd graph or GPU
        # memory is kept alive across the epoch.
        run_loss += loss.item() * scale
        run_cls += cls_value * scale
        run_offset += offset_value * scale
        run_lands += lands_value * scale
        run_correct += batch_correct
        run_gt += batch_gt
    return run_loss, run_cls, run_offset, run_lands, run_correct, run_gt


def _train_pnet(root_path, pretrained):
    """Train PNet and save the best-validation-loss weights under ``<root>/train``."""
    root_path = _resolve_mtcnn_root(root_path)
    train_dir = os.path.join(root_path, 'train')
    if train_dir not in sys.path:
        sys.path.append(train_dir)
    from Data_Loading import ListDataset

    anno_dir = root_path + '/data_preprocessing/anno_store'
    batch_size = 64
    # Build each dataset once and reuse it for the loader and the size lookups.
    datasets = {
        'train': ListDataset(anno_dir + '/imglist_anno_12.txt'),
        'val': ListDataset(anno_dir + '/imglist_anno_12_val.txt'),
    }
    dataloaders = {
        phase: torch.utils.data.DataLoader(ds, batch_size=batch_size, shuffle=True, num_workers=12)
        for phase, ds in datasets.items()
    }
    dataset_sizes = {phase: len(ds) for phase, ds in datasets.items()}
    print('training dataset loaded with length : {}'.format(dataset_sizes['train']))
    print('validation dataset loaded with length : {}'.format(dataset_sizes['val']))

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    # A torch.device never compares equal to a plain string; inspect .type.
    print('Training with CPU' if device.type == 'cpu' else 'Training with GPU')

    # Load the model and initialise (or restore) its weights.
    model = PNet(is_train=True)
    model.apply(weights_init)
    model = model.to(device)
    if pretrained:
        print('loading pretrained model from %s' % pretrained)
        model.load_state_dict(torch.load(pretrained, map_location=device))
    print("Pnet loaded")

    train_logging_file = os.path.join(train_dir, 'Pnet_train_logging.txt')

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999), weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', cooldown=1, patience=2, verbose=True)
    criteria = (
        nn.CrossEntropyLoss(reduction="none").to(device),
        nn.MSELoss(reduction="none").to(device),
        nn.SmoothL1Loss(reduction="none").to(device),
    )

    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = 100
    num_epochs = 100  # only shown in the progress banner; the loop stops on loss/patience
    epoch = 0
    toler = 0
    max_toler = 30
    while best_loss > 1e-3 and toler < max_toler:
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        epoch += 1
        for phase in ['train', 'val']:
            total_loss, total_cls, total_offset, total_lands, correct, num_gt = _run_phase(
                phase, model, dataloaders[phase], optimizer, criteria, device, batch_size)
            size = dataset_sizes[phase]
            epoch_loss = total_loss / size
            epoch_loss_cls = total_cls / size
            epoch_loss_offset = total_offset / size
            epoch_loss_lands = total_lands / size
            epoch_accuracy = correct / (num_gt + 1e-16)
            if phase == 'train':
                scheduler.step(epoch_loss)

            summary = ('{} Loss: {:.4f} accuracy: {:.4f} cls Loss: {:.4f} offset Loss: {:.4f} landmarks Loss: {:.4f}'
                       .format(phase, epoch_loss, epoch_accuracy, epoch_loss_cls, epoch_loss_offset, epoch_loss_lands))
            print(summary + ' EarlyStop: {}/{}'.format(toler, max_toler))
            with open(train_logging_file, 'a') as f:
                f.write(summary + '\n')

            # Early stopping: keep a copy of the weights whenever val loss improves by >= 1e-4.
            if phase == 'val':
                if best_loss - epoch_loss >= 1e-4:
                    toler = 0
                    best_loss = epoch_loss
                    best_model_wts = copy.deepcopy(model.state_dict())
                else:
                    toler += 1

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best loss: {:.4f}'.format(best_loss))

    model.load_state_dict(best_model_wts)
    weights_path = os.path.join(train_dir, 'pnet_Weights')
    torch.save(model.state_dict(), weights_path)
    print('model save in:', weights_path)


def trainP(root_path, pretrained=''):
    """Train the MTCNN PNet and save its best weights.

    Side effects:
        - Appends ``<root>/train`` to ``sys.path`` (once) to import ``ListDataset``.
        - Appends per-phase metrics to ``<root>/train/Pnet_train_logging.txt``.
        - Writes the best-validation-loss weights to ``<root>/train/pnet_Weights``.

    ``<root>`` is ``root_path`` re-rooted under ``/notebooks`` and truncated
    after its ``MTCNN`` component.

    Args:
        root_path: Path inside the MTCNN project tree.
        pretrained: Optional path to a PNet state_dict to start from;
            an empty string (or ``None``) trains from scratch.

    Returns:
        True once training completes and the weights have been saved.

    Raises:
        Any exception raised during training is re-raised unchanged, after
        the CUDA cache has been emptied.
    """
    try:
        # Training runs in a helper frame, so its model/optimizer/tensors are
        # released on normal return before the cache is emptied below.
        _train_pnet(root_path, pretrained)
        return True
    finally:
        torch.cuda.empty_cache()