Commit 1b13bc83 authored by Tiago de Freitas Pereira

Created a dataset to fetch directories

parent 32e25f63
Merge request !21: Resolve "Adopt to the Estimators API"
bob/learn/tensorflow/test/data/dummy_image_database/m301_01_p01_i0_0.png (60 KiB)
bob/learn/tensorflow/test/data/dummy_image_database/m301_01_p02_i0_0.png (61.3 KiB)
bob/learn/tensorflow/test/data/dummy_image_database/m304_01_p01_i0_0.png (70.8 KiB)
bob/learn/tensorflow/test/data/dummy_image_database/m304_02_f12_i0_0.png (66.9 KiB)

#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>

import tensorflow as tf

from bob.learn.tensorflow.network import dummy
from bob.learn.tensorflow.trainers import LogitsTrainer, LogitsCenterLossTrainer
from bob.learn.tensorflow.dataset.image import shuffle_data_and_labels_image_augmentation
from bob.learn.tensorflow.dataset import append_image_augmentation
from bob.learn.tensorflow.utils import load_mnist, create_mnist_tfrecord
from bob.learn.tensorflow.utils.hooks import LoggerHookEstimator
from bob.learn.tensorflow.utils import reproducible
from bob.learn.tensorflow.loss import mean_cross_entropy_loss

import pkg_resources
import numpy
import shutil
import os

model_dir = "./temp"

learning_rate = 0.1
data_shape = (250, 250, 3)  # shape used for the input images
data_type = tf.float32
batch_size = 16
validation_batch_size = 250
epochs = 1
steps = 5000


def test_logitstrainer_images():
    # Trainer logits
    try:
        embedding_validation = False
        trainer = LogitsTrainer(model_dir=model_dir,
                                architecture=dummy,
                                optimizer=tf.train.GradientDescentOptimizer(learning_rate),
                                n_classes=10,
                                loss_op=mean_cross_entropy_loss,
                                embedding_validation=embedding_validation,
                                validation_batch_size=validation_batch_size)
        run_logitstrainer_images(trainer)
    finally:
        try:
            # Remove the model directory created by the trainer
            shutil.rmtree(model_dir, ignore_errors=True)
        except Exception:
            pass


def run_logitstrainer_images(trainer):
    # Cleaning up
    tf.reset_default_graph()
    assert len(tf.global_variables()) == 0

    filenames = [pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m301_01_p01_i0_0.png'),
                 pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m301_01_p02_i0_0.png'),
                 pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m304_01_p01_i0_0.png'),
                 pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m304_02_f12_i0_0.png')]
    labels = [0, 0, 1, 1]

    def input_fn():
        return shuffle_data_and_labels_image_augmentation(filenames, labels, data_shape, data_type,
                                                          batch_size, epochs=epochs)

    def input_fn_validation():
        return shuffle_data_and_labels_image_augmentation(filenames, labels, data_shape, data_type,
                                                          validation_batch_size, epochs=1000)

    hooks = [LoggerHookEstimator(trainer, 16, 300),
             tf.train.SummarySaverHook(save_steps=1000,
                                       output_dir=model_dir,
                                       scaffold=tf.train.Scaffold(),
                                       summary_writer=tf.summary.FileWriter(model_dir))]

    trainer.train(input_fn, steps=steps, hooks=hooks)

    if not trainer.embedding_validation:
        acc = trainer.evaluate(input_fn_validation)
        assert acc['accuracy'] > 0.80
    else:
        acc = trainer.evaluate(input_fn_validation)
        assert acc['accuracy'] > 0.80

    # Cleaning up
    tf.reset_default_graph()
    assert len(tf.global_variables()) == 0
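
For context, the test above hands the output of shuffle_data_and_labels_image_augmentation to the trainer as an Estimator-style input_fn. The snippet below is a minimal sketch of how such an image-loading input_fn can be assembled with tf.data in TensorFlow 1.x; the helper name make_image_input_fn, the feature key 'data', and the omission of the augmentation step are assumptions for illustration, not the actual bob.learn.tensorflow implementation.

import tensorflow as tf


def make_image_input_fn(filenames, labels, data_shape, data_type,
                        batch_size, epochs=1, buffer_size=100):
    """Hypothetical sketch of an Estimator input_fn that reads PNG files.

    This is not shuffle_data_and_labels_image_augmentation; it only
    illustrates a decode/resize/shuffle/batch/repeat tf.data pipeline.
    """
    def input_fn():
        dataset = tf.data.Dataset.from_tensor_slices(
            (tf.constant(filenames), tf.constant(labels, dtype=tf.int64)))

        def _parse(filename, label):
            # Decode the PNG and bring it to the shape/type the network expects
            image = tf.image.decode_png(tf.read_file(filename), channels=data_shape[2])
            image = tf.image.resize_images(image, data_shape[0:2])
            image = tf.cast(image, data_type)
            return image, label

        dataset = (dataset.map(_parse)
                          .shuffle(buffer_size)
                          .repeat(epochs)
                          .batch(batch_size))
        images, batch_labels = dataset.make_one_shot_iterator().get_next()
        # The feature key 'data' is an assumption; the trainer's model_fn
        # decides which key it actually consumes.
        return {'data': images}, batch_labels
    return input_fn

With such a helper, the test's input_fn could be written as make_image_input_fn(filenames, labels, data_shape, data_type, batch_size, epochs=epochs).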