#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>

import tensorflow as tf

from bob.learn.tensorflow.network import dummy
from bob.learn.tensorflow.trainers import LogitsTrainer, LogitsCenterLossTrainer

from bob.learn.tensorflow.dataset.image import shuffle_data_and_labels_image_augmentation
import pkg_resources

from bob.learn.tensorflow.dataset import append_image_augmentation
from bob.learn.tensorflow.utils import load_mnist, create_mnist_tfrecord
from bob.learn.tensorflow.utils.hooks import LoggerHookEstimator
from bob.learn.tensorflow.utils import reproducible
from bob.learn.tensorflow.loss import mean_cross_entropy_loss

import numpy

import shutil
import os

# Output directory handed to the trainer and the summary writer; the test
# removes it again once it finishes.
model_dir = "./temp"

learning_rate = 0.1  # step size of the GradientDescentOptimizer
# Shape the input pipeline produces for each image, presumably
# (height, width, channels).  NOTE(review): an older comment called this the
# "atnt" size, but AT&T faces are 112x92 grayscale -- confirm against the
# files in data/dummy_image_database.
data_shape = (250, 250, 3)
data_type = tf.float32  # dtype passed to the image input pipeline
batch_size = 16  # batch size of the training input_fn
validation_batch_size = 250  # batch size of the evaluation input_fn
epochs = 1  # epochs requested from the training input_fn
# Step budget for trainer.train; training presumably stops earlier if the
# training input_fn is exhausted first -- confirm.
steps = 5000


def test_logitstrainer_images():
    """Train a ``LogitsTrainer`` on the dummy image database and check it.

    The trainer writes into ``model_dir``; that directory is always removed
    afterwards, even if building or running the trainer fails.
    """
    try:
        embedding_validation = False
        trainer = LogitsTrainer(model_dir=model_dir,
                                architecture=dummy,
                                optimizer=tf.train.GradientDescentOptimizer(learning_rate),
                                n_classes=10,
                                loss_op=mean_cross_entropy_loss,
                                embedding_validation=embedding_validation,
                                validation_batch_size=validation_batch_size)
        run_logitstrainer_images(trainer)
    finally:
        # This test reads PNG files directly and creates no tfrecord files,
        # so only the model directory needs removing.  ignore_errors=True
        # keeps the cleanup best-effort (e.g. if the directory was never
        # created) without masking failures from the try body.
        shutil.rmtree(model_dir, ignore_errors=True)

def run_logitstrainer_images(trainer):
    """Train ``trainer`` on the dummy image database and assert its accuracy.

    The default TensorFlow graph is reset before and after the run so that
    no variables leak into other tests.

    :param trainer: a configured trainer exposing
        ``train(input_fn, steps=..., hooks=...)`` and ``evaluate(input_fn)``;
        ``evaluate`` must return a mapping with an ``'accuracy'`` entry.
    :raises AssertionError: if the graph is not clean or the evaluation
        accuracy is not above 0.80.
    """
    # Cleaning up
    tf.reset_default_graph()
    assert len(tf.global_variables()) == 0

    image_names = ['m301_01_p01_i0_0.png',
                   'm301_01_p02_i0_0.png',
                   'm304_01_p01_i0_0.png',
                   'm304_02_f12_i0_0.png']
    filenames = [pkg_resources.resource_filename(__name__, 'data/dummy_image_database/' + name)
                 for name in image_names]
    # labels[i] is the class of filenames[i]: two classes, two images each.
    labels = [0, 0, 1, 1]

    def input_fn():
        return shuffle_data_and_labels_image_augmentation(filenames, labels, data_shape, data_type,
                                                          batch_size, epochs=epochs)

    def input_fn_validation():
        return shuffle_data_and_labels_image_augmentation(filenames, labels, data_shape, data_type,
                                                          validation_batch_size, epochs=1000)

    # Keep a handle on the writer so its event file is closed once training
    # is over instead of staying open until garbage collection.
    summary_writer = tf.summary.FileWriter(model_dir)
    try:
        hooks = [LoggerHookEstimator(trainer, 16, 300),
                 tf.train.SummarySaverHook(save_steps=1000,
                                           output_dir=model_dir,
                                           scaffold=tf.train.Scaffold(),
                                           summary_writer=summary_writer)]
        trainer.train(input_fn, steps=steps, hooks=hooks)
    finally:
        summary_writer.close()

    # The original code branched on trainer.embedding_validation, but both
    # branches were identical, so a single evaluation covers both modes.
    acc = trainer.evaluate(input_fn_validation)
    assert acc['accuracy'] > 0.80, acc

    # Cleaning up
    tf.reset_default_graph()
    assert len(tf.global_variables()) == 0