#!/usr/bin/python
# encoding: utf-8

import random
import torch
from torch.utils.data import Dataset
from torch.utils.data import sampler
import torchvision.transforms as transforms
import lmdb
import six
import sys
from PIL import Image
import numpy as np
from crnn_pytorch_Gas.data_augment import *


class lmdbDataset(Dataset):
    """Map-style dataset reading (image, label) pairs from an LMDB store.

    The store must contain a ``num-samples`` entry and, for every 1-based
    record number ``k``, the keys ``image-%09d`` (encoded image bytes that
    PIL can open) and ``label-%09d`` (raw label bytes). Images are decoded
    with PIL and converted to single-channel ``'L'`` mode.

    Args:
        root: Path of the LMDB environment, opened read-only.
        transform: Optional callable applied to the decoded PIL image.
        target_transform: Optional callable applied to the raw label bytes.
    """

    def __init__(self, root=None, transform=None, target_transform=None):
        self.env = lmdb.open(
            root,
            max_readers=1,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False)

        if not self.env:
            print('cannot creat lmdb from %s' % (root))
            # Non-zero status: failing to open the store is an error, not a
            # clean exit.
            sys.exit(1)

        with self.env.begin(write=False) as txn:
            nSamples = int(txn.get('num-samples'.encode('utf-8')))
            self.nSamples = nSamples

        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return self.nSamples

    def __getitem__(self, index):
        """Return ``(image, label)`` for the 0-based ``index``.

        If the stored image cannot be decoded, the next sample is returned
        instead (wrapping around to the first sample after the last one).

        Raises:
            IndexError: if ``index`` is outside ``[0, len(self))``. Raising
                IndexError (rather than asserting) also lets a plain ``for``
                loop over the dataset terminate cleanly.
        """
        # Accept tensor / numpy scalar indices as well as Python ints.
        index = int(index)
        if not 0 <= index < len(self):
            raise IndexError('index range error')
        # LMDB record keys are 1-based.
        key_id = index + 1
        with self.env.begin(write=False) as txn:
            img_key = 'image-%09d' % key_id
            imgbuf = txn.get(img_key.encode('utf-8'))

            buf = six.BytesIO()
            buf.write(imgbuf)
            buf.seek(0)
            try:
                img = Image.open(buf).convert('L')
            except IOError:
                print('Corrupted image for %d' % key_id)
                # Fall back to the *next* sample (0-based index + 1), wrapping
                # so the last record cannot step out of range.
                return self[(index + 1) % len(self)]

            if self.transform is not None:
                img = self.transform(img)

            label_key = 'label-%09d' % key_id
            label = txn.get(label_key.encode('utf-8'))

            if self.target_transform is not None:
                label = self.target_transform(label)

        return (img, label)

import cv2
import numpy as np
class resizeNormalize(object):
    """Resize an image to a fixed size and return a tensor normalized to [-1, 1].

    Args:
        size: Target size as ``(width, height)``, the argument order used by
            both ``PIL.Image.resize`` and ``cv2.resize``.
        is_train: If True, ``img`` is expected to be a PIL image and is
            resized with PIL using ``interpolation``. Otherwise it is resized
            with OpenCV using bilinear interpolation; ``interpolation`` is not
            used on that path.
        interpolation: PIL resampling filter for the ``is_train`` path.
    """

    def __init__(self, size, is_train=True, interpolation=Image.BILINEAR):
        self.size = size
        self.interpolation = interpolation
        self.toTensor = transforms.ToTensor()
        self.is_train = is_train

    def __call__(self, img):
        """Resize ``img`` to ``self.size`` and normalize it.

        For 8-bit input, ``ToTensor`` scales pixels to [0, 1]; the in-place
        ``(x - 0.5) / 0.5`` then maps them to [-1, 1].
        """
        if self.is_train:
            img = img.resize(self.size, self.interpolation)
        else:
            if isinstance(img, Image.Image):
                # cv2.resize only accepts ndarrays; also accept PIL input,
                # which is what the dataset produces.
                img = np.asarray(img)
            img = cv2.resize(img, self.size, interpolation=cv2.INTER_LINEAR)
        tensor = self.toTensor(img)
        tensor.sub_(0.5).div_(0.5)
        return tensor



class randomSequentialSampler(sampler.Sampler):
    """Sampler that emits random contiguous runs of indices, one per batch.

    For each full batch a random start ``s`` is drawn, and the indices
    ``s, s+1, ..., s+batch_size-1`` are emitted, so every batch reads a
    contiguous block of the dataset. Runs are drawn independently, so
    indices may repeat across batches. A final partial batch of
    ``len % batch_size`` indices is produced the same way. Exactly
    ``len(data_source)`` indices are yielded, as Python ints.

    Args:
        data_source: Dataset to sample from; only its ``len()`` is used.
        batch_size: Number of indices in each contiguous run.
    """

    def __init__(self, data_source, batch_size):
        self.num_samples = len(data_source)
        self.batch_size = batch_size

    def _random_run(self, length):
        """Return ``length`` consecutive indices from a random valid start."""
        # Clamp to 0 so a dataset smaller than one batch does not hand
        # randint an empty range.
        start = random.randint(0, max(self.num_samples - self.batch_size, 0))
        return torch.arange(start, start + length, dtype=torch.long)

    def __iter__(self):
        n = len(self)
        n_batch = n // self.batch_size
        tail = n % self.batch_size
        index = torch.zeros(n, dtype=torch.long)
        for i in range(n_batch):
            lo = i * self.batch_size
            index[lo:lo + self.batch_size] = self._random_run(self.batch_size)
        # deal with tail: offset by n_batch * batch_size rather than the loop
        # variable, so this also works when there are no full batches.
        if tail:
            index[n_batch * self.batch_size:] = self._random_run(tail)

        return iter(index.tolist())

    def __len__(self):
        return self.num_samples


class dataAugment(object):
    """Randomly augment a grayscale PIL image.

    Half of the time the input is returned unchanged. Otherwise the image is
    converted to a 3-channel OpenCV (BGR) array. It always gets ``distort``
    and, each with probability 1/2, Gaussian blur, an affine transform and
    random erasing. The result is converted back to a single-channel PIL
    image.

    The input must be single-channel (e.g. PIL mode ``'L'``), because it is
    converted with ``cv2.COLOR_GRAY2BGR``.
    """

    def __call__(self, img):
        if random.randrange(2):  # do nothing
            return img

        # PIL grayscale -> OpenCV 3-channel BGR array.
        bgr = cv2.cvtColor(np.array(img), cv2.COLOR_GRAY2BGR)

        bgr = distort(bgr)  # random brightness / contrast / hue
        if random.randrange(2):
            bgr = rand_GaussianBlur(bgr)  # random blur
        if random.randrange(2):
            bgr = AffineTrans(bgr)  # random affine transform
        if random.randrange(2):
            bgr = rand_Erasing(bgr)  # random erasing

        # OpenCV BGR -> single-channel PIL image.
        image = Image.fromarray(cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY))
        return image

class alignCollate(object):
    """Collate ``(image, label)`` pairs into one normalized image tensor.

    Args:
        imgH: Output image height.
        imgW: Output image width; replaced per batch when ``keep_ratio`` is
            True.
        keep_ratio: If True, the output width for each batch is
            ``floor(max_ratio * imgH)``, where ``max_ratio`` is the largest
            width/height ratio in the batch. It is never smaller than
            ``imgH * min_ratio``.
        min_ratio: Lower bound on the output width/height ratio when
            ``keep_ratio`` is True.
        augment: If True, run ``dataAugment`` on each image before resizing.
            Defaults to False, which matches the previous behavior, where
            augmentation was disabled.
    """

    def __init__(self, imgH=32, imgW=100, keep_ratio=False, min_ratio=1,
                 augment=False):
        self.imgH = imgH
        self.imgW = imgW
        self.keep_ratio = keep_ratio
        self.min_ratio = min_ratio
        self.augment = augment

    def __call__(self, batch):
        """Return ``(images, labels)`` for a list of ``(image, label)`` pairs.

        ``images`` holds every resized and normalized image, stacked along a
        new leading batch dimension. ``labels`` is the tuple of labels in
        batch order.
        """
        images, labels = zip(*batch)

        imgH = self.imgH
        imgW = self.imgW
        if self.keep_ratio:
            # ``.size`` is (width, height) for the PIL images used here.
            max_ratio = max(w / float(h) for w, h in (im.size for im in images))
            imgW = int(np.floor(max_ratio * imgH))
            # Ensure imgW >= imgH * min_ratio. Cast to int because PIL's resize
            # needs integer sizes and min_ratio may be a float.
            imgW = int(max(imgH * self.min_ratio, imgW))

        if self.augment:
            augmenter = dataAugment()
            images = [augmenter(image) for image in images]

        transform = resizeNormalize((imgW, imgH))
        images = torch.stack([transform(image) for image in images], 0)

        return images, labels
