diff --git a/bob/learn/tensorflow/test/data/dummy_image_database/m301_01_p01_i0_0.png b/bob/learn/tensorflow/test/data/dummy_image_database/m301_01_p01_i0_0.png new file mode 100644 index 0000000000000000000000000000000000000000..52d39487637b8a7ba460c93ecc9e1bb92e5ca42f Binary files /dev/null and b/bob/learn/tensorflow/test/data/dummy_image_database/m301_01_p01_i0_0.png differ diff --git a/bob/learn/tensorflow/test/data/dummy_image_database/m301_01_p02_i0_0.png b/bob/learn/tensorflow/test/data/dummy_image_database/m301_01_p02_i0_0.png new file mode 100644 index 0000000000000000000000000000000000000000..0c7e298de460379d02de275c38ebc24840a258fa Binary files /dev/null and b/bob/learn/tensorflow/test/data/dummy_image_database/m301_01_p02_i0_0.png differ diff --git a/bob/learn/tensorflow/test/data/dummy_image_database/m304_01_p01_i0_0.png b/bob/learn/tensorflow/test/data/dummy_image_database/m304_01_p01_i0_0.png new file mode 100644 index 0000000000000000000000000000000000000000..53c25af50711c607d2d05cb9566acfe2b140977d Binary files /dev/null and b/bob/learn/tensorflow/test/data/dummy_image_database/m304_01_p01_i0_0.png differ diff --git a/bob/learn/tensorflow/test/data/dummy_image_database/m304_02_f12_i0_0.png b/bob/learn/tensorflow/test/data/dummy_image_database/m304_02_f12_i0_0.png new file mode 100644 index 0000000000000000000000000000000000000000..0fdf6b4d8fa118657bddd7b2b219d96085180d74 Binary files /dev/null and b/bob/learn/tensorflow/test/data/dummy_image_database/m304_02_f12_i0_0.png differ diff --git a/bob/learn/tensorflow/test/test_image_dataset.py b/bob/learn/tensorflow/test/test_image_dataset.py new file mode 100755 index 0000000000000000000000000000000000000000..5e0e1b1960de4f4337372cc3ae6f2afb3f481d6a --- /dev/null +++ b/bob/learn/tensorflow/test/test_image_dataset.py @@ -0,0 +1,96 @@ +#!/usr/bin/env python +# vim: set fileencoding=utf-8 : +# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch> + +import tensorflow as tf + +from bob.learn.tensorflow.network import dummy +from bob.learn.tensorflow.trainers import LogitsTrainer, LogitsCenterLossTrainer + +from bob.learn.tensorflow.dataset.image import shuffle_data_and_labels_image_augmentation +import pkg_resources + +from bob.learn.tensorflow.dataset import append_image_augmentation +from bob.learn.tensorflow.utils import load_mnist, create_mnist_tfrecord +from bob.learn.tensorflow.utils.hooks import LoggerHookEstimator +from bob.learn.tensorflow.utils import reproducible +from bob.learn.tensorflow.loss import mean_cross_entropy_loss + +import numpy + +import shutil +import os + +model_dir = "./temp" + +learning_rate = 0.1 +data_shape = (250, 250, 3) # size of atnt images +data_type = tf.float32 +batch_size = 16 +validation_batch_size = 250 +epochs = 1 +steps = 5000 + + +def test_logitstrainer_images(): + + # Trainer logits + try: + embedding_validation = False + trainer = LogitsTrainer(model_dir=model_dir, + architecture=dummy, + optimizer=tf.train.GradientDescentOptimizer(learning_rate), + n_classes=10, + loss_op=mean_cross_entropy_loss, + embedding_validation=embedding_validation, + validation_batch_size=validation_batch_size) + run_logitstrainer_images(trainer) + finally: + try: + os.unlink(tfrecord_train) + os.unlink(tfrecord_validation) + shutil.rmtree(model_dir, ignore_errors=True) + except Exception: + pass + +def run_logitstrainer_images(trainer): + # Cleaning up + tf.reset_default_graph() + assert len(tf.global_variables()) == 0 + + filenames = [pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m301_01_p01_i0_0.png'), + pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m301_01_p02_i0_0.png'), + pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m304_01_p01_i0_0.png'), + pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m304_02_f12_i0_0.png')] + labels = [0, 0, 1, 1] + + def input_fn(): + + return shuffle_data_and_labels_image_augmentation(filenames,labels, data_shape, data_type, batch_size, epochs=epochs) + + + def input_fn_validation(): + return shuffle_data_and_labels_image_augmentation(filenames, labels, data_shape, data_type, + validation_batch_size, epochs=1000) + + hooks = [LoggerHookEstimator(trainer, 16, 300), + + tf.train.SummarySaverHook(save_steps=1000, + output_dir=model_dir, + scaffold=tf.train.Scaffold(), + summary_writer=tf.summary.FileWriter(model_dir) )] + + trainer.train(input_fn, steps=steps, hooks=hooks) + + if not trainer.embedding_validation: + acc = trainer.evaluate(input_fn_validation) + assert acc['accuracy'] > 0.80 + else: + acc = trainer.evaluate(input_fn_validation) + assert acc['accuracy'] > 0.80 + + # Cleaning up + tf.reset_default_graph() + assert len(tf.global_variables()) == 0 + +