diff --git a/bob/learn/tensorflow/dataset/__init__.py b/bob/learn/tensorflow/dataset/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e0221e60a7a94e340d318532aee28e0fa9b469a7
--- /dev/null
+++ b/bob/learn/tensorflow/dataset/__init__.py
@@ -0,0 +1,71 @@
+import tensorflow as tf
+
+
+DEFAULT_FEATURE = {'train/data': tf.FixedLenFeature([], tf.string),
+                   'train/label': tf.FixedLenFeature([], tf.int64)}
+
+
+def append_image_augmentation(image, gray_scale=False,
+                              output_shape=None,
+                              random_flip=False,
+                              random_brightness=False,
+                              random_contrast=False,
+                              random_saturation=False,
+                              per_image_normalization=True):
+    """
+    Appends some random image augmentation operations to the given tensor
+
+    **Parameters**
+
+    gray_scale:
+        Convert to gray scale?
+
+    output_shape:
+        If set, crops or pads the image to the given output shape
+
+    random_flip:
+        Randomly flip an image horizontally (https://www.tensorflow.org/api_docs/python/tf/image/random_flip_left_right)
+
+    random_brightness:
+        Adjust the brightness of an RGB image by a random factor (https://www.tensorflow.org/api_docs/python/tf/image/random_brightness)
+
+    random_contrast:
+        Adjust the contrast of an RGB image by a random factor (https://www.tensorflow.org/api_docs/python/tf/image/random_contrast)
+
+    random_saturation:
+        Adjust the saturation of an RGB image by a random factor (https://www.tensorflow.org/api_docs/python/tf/image/random_saturation)
+
+    per_image_normalization:
+        Linearly scales image to have zero mean and unit norm.
+
+    """
+
+    # Casting to float32
+    image = tf.cast(image, tf.float32)
+
+    if output_shape is not None:
+        assert len(output_shape) == 2
+        image = tf.image.resize_image_with_crop_or_pad(image, output_shape[0], output_shape[1])
+
+    if random_flip:
+        image = tf.image.random_flip_left_right(image)
+
+    # The delta/range values below are illustrative defaults; the original
+    # calls omitted these arguments, which tf.image requires.
+    if random_brightness:
+        image = tf.image.random_brightness(image, max_delta=0.15)
+
+    if random_contrast:
+        image = tf.image.random_contrast(image, lower=0.85, upper=1.15)
+
+    if random_saturation:
+        image = tf.image.random_saturation(image, lower=0.85, upper=1.15)
+
+    if gray_scale:
+        image = tf.image.rgb_to_grayscale(image, name="rgb_to_gray")
+
+    # normalizing data
+    if per_image_normalization:
+        image = tf.image.per_image_standardization(image)
+
+    return image
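For orientation, here is a minimal usage sketch of the helper above. It is not part of the patch, and the placeholder shape is invented for illustration:

```python
import tensorflow as tf
from bob.learn.tensorflow.dataset import append_image_augmentation

# A single HxWxC image tensor; the shape below is only an example.
image = tf.placeholder(tf.uint8, shape=(112, 112, 3))

# Returns a float32 tensor with the requested augmentations applied.
augmented = append_image_augmentation(image,
                                      random_flip=True,
                                      random_brightness=True,
                                      per_image_normalization=True)
```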
diff --git a/bob/learn/tensorflow/dataset/image.py b/bob/learn/tensorflow/dataset/image.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c5930500b9d747ca5566fad9a93bd6fbd11f6b9
--- /dev/null
+++ b/bob/learn/tensorflow/dataset/image.py
@@ -0,0 +1,166 @@
+#!/usr/bin/env python
+# vim: set fileencoding=utf-8 :
+# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
+
+import tensorflow as tf
+from functools import partial
+from . import append_image_augmentation
+
+
+def shuffle_data_and_labels_image_augmentation(filenames, labels, data_shape, data_type,
+                                               batch_size, epochs=None, buffer_size=10**3,
+                                               gray_scale=False,
+                                               output_shape=None,
+                                               random_flip=False,
+                                               random_brightness=False,
+                                               random_contrast=False,
+                                               random_saturation=False,
+                                               per_image_normalization=True):
+    """
+    Dumps random batches from a list of image paths and labels.
+
+    The list of files and labels should be in the same order, e.g.::
+
+        filenames = ['class_1_img1', 'class_1_img2', 'class_2_img1']
+        labels = [0, 0, 1]
+
+    **Parameters**
+
+    filenames:
+        List containing the paths of the images
+
+    labels:
+        List containing the labels (needs to be in EXACT same order as filenames)
+
+    data_shape:
+        Shape of the decoded images
+
+    data_type:
+        tf data type (https://www.tensorflow.org/versions/r0.12/resources/dims_types#data_types)
+
+    batch_size:
+        Size of the batch
+
+    epochs:
+        Number of epochs to be batched
+
+    buffer_size:
+        Size of the shuffle buffer
+
+    gray_scale:
+        Convert to gray scale?
+
+    output_shape:
+        If set, crops or pads the image to the given output shape
+
+    random_flip:
+        Randomly flip an image horizontally (https://www.tensorflow.org/api_docs/python/tf/image/random_flip_left_right)
+
+    random_brightness:
+        Adjust the brightness of an RGB image by a random factor (https://www.tensorflow.org/api_docs/python/tf/image/random_brightness)
+
+    random_contrast:
+        Adjust the contrast of an RGB image by a random factor (https://www.tensorflow.org/api_docs/python/tf/image/random_contrast)
+
+    random_saturation:
+        Adjust the saturation of an RGB image by a random factor (https://www.tensorflow.org/api_docs/python/tf/image/random_saturation)
+
+    per_image_normalization:
+        Linearly scales image to have zero mean and unit norm.
+
+    """
+
+    dataset = create_dataset_from_path_augmentation(filenames, labels, data_shape,
+                                                    data_type,
+                                                    gray_scale=gray_scale,
+                                                    output_shape=output_shape,
+                                                    random_flip=random_flip,
+                                                    random_brightness=random_brightness,
+                                                    random_contrast=random_contrast,
+                                                    random_saturation=random_saturation,
+                                                    per_image_normalization=per_image_normalization)
+
+    dataset = dataset.shuffle(buffer_size).batch(batch_size).repeat(epochs)
+
+    data, labels = dataset.make_one_shot_iterator().get_next()
+    return data, labels
+
+
+def create_dataset_from_path_augmentation(filenames, labels,
+                                          data_shape, data_type,
+                                          gray_scale=False,
+                                          output_shape=None,
+                                          random_flip=False,
+                                          random_brightness=False,
+                                          random_contrast=False,
+                                          random_saturation=False,
+                                          per_image_normalization=True):
+    """
+    Create a dataset from a list of image paths and labels
+
+    **Parameters**
+
+    filenames:
+        List containing the paths of the images
+
+    labels:
+        List containing the labels (needs to be in EXACT same order as filenames)
+
+    data_shape:
+        Shape of the decoded images
+
+    data_type:
+        tf data type (https://www.tensorflow.org/versions/r0.12/resources/dims_types#data_types)
+
+    """
+
+    parser = partial(image_augmentation_parser,
+                     data_shape=data_shape,
+                     data_type=data_type,
+                     gray_scale=gray_scale,
+                     output_shape=output_shape,
+                     random_flip=random_flip,
+                     random_brightness=random_brightness,
+                     random_contrast=random_contrast,
+                     random_saturation=random_saturation,
+                     per_image_normalization=per_image_normalization)
+
+    dataset = tf.contrib.data.Dataset.from_tensor_slices((filenames, labels))
+    dataset = dataset.map(parser)
+    return dataset
+
+
+def image_augmentation_parser(filename, label, data_shape, data_type,
+                              gray_scale=False,
+                              output_shape=None,
+                              random_flip=False,
+                              random_brightness=False,
+                              random_contrast=False,
+                              random_saturation=False,
+                              per_image_normalization=True):
+    """
+    Decodes one image file and its label into tensors, applying the
+    configured augmentations.
+ """ + + # Convert the image data from string back to the numbers + image = tf.cast(tf.image.decode_image(tf.read_file(filename)), tf.float32) + + # Reshape image data into the original shape + image = tf.reshape(image, data_shape) + + #Applying image augmentation + image = append_image_augmentation(image, gray_scale=gray_scale, + output_shape=output_shape, + random_flip=random_flip, + random_brightness=random_brightness, + random_contrast=random_contrast, + random_saturation=random_saturation, + per_image_normalization=per_image_normalization) + + label = tf.cast(label, tf.int64) + + return image, label + diff --git a/bob/learn/tensorflow/dataset/tfrecords.py b/bob/learn/tensorflow/dataset/tfrecords.py new file mode 100644 index 0000000000000000000000000000000000000000..8e4713110e265dbe6969899d680c6d668301ea8f --- /dev/null +++ b/bob/learn/tensorflow/dataset/tfrecords.py @@ -0,0 +1,281 @@ +from functools import partial +import tensorflow as tf +from . import append_image_augmentation, DEFAULT_FEATURE + + +def example_parser(serialized_example, feature, data_shape, data_type): + """ + Parses a single tf.Example into image and label tensors. + + """ + # Decode the record read by the reader + features = tf.parse_single_example(serialized_example, features=feature) + # Convert the image data from string back to the numbers + image = tf.decode_raw(features['train/data'], data_type) + # Cast label data into int64 + label = tf.cast(features['train/label'], tf.int64) + # Reshape image data into the original shape + image = tf.reshape(image, data_shape) + return image, label + + +def image_augmentation_parser(serialized_example, feature, data_shape, data_type, + gray_scale=False, + output_shape=None, + random_flip=False, + random_brightness=False, + random_contrast=False, + random_saturation=False, + per_image_normalization=True): + + """ + Parses a single tf.Example into image and label tensors. + + """ + # Decode the record read by the reader + features = tf.parse_single_example(serialized_example, features=feature) + # Convert the image data from string back to the numbers + image = tf.decode_raw(features['train/data'], data_type) + + # Reshape image data into the original shape + image = tf.reshape(image, data_shape) + + #Applying image augmentation + image = append_image_augmentation(image, gray_scale=gray_scale, + output_shape=output_shape, + random_flip=random_flip, + random_brightness=random_brightness, + random_contrast=random_contrast, + random_saturation=random_saturation, + per_image_normalization=per_image_normalization) + + # Cast label data into int64 + label = tf.cast(features['train/label'], tf.int64) + return image, label + + +def read_and_decode(filename_queue, data_shape, data_type=tf.float32, + feature=None): + + """ + Simples parse possible for a tfrecord. 
diff --git a/bob/learn/tensorflow/dataset/tfrecords.py b/bob/learn/tensorflow/dataset/tfrecords.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e4713110e265dbe6969899d680c6d668301ea8f
--- /dev/null
+++ b/bob/learn/tensorflow/dataset/tfrecords.py
@@ -0,0 +1,281 @@
+from functools import partial
+import tensorflow as tf
+from . import append_image_augmentation, DEFAULT_FEATURE
+
+
+def example_parser(serialized_example, feature, data_shape, data_type):
+    """
+    Parses a single tf.Example into image and label tensors.
+    """
+    # Decode the record read by the reader
+    features = tf.parse_single_example(serialized_example, features=feature)
+    # Convert the image data from string back to the numbers
+    image = tf.decode_raw(features['train/data'], data_type)
+    # Cast label data into int64
+    label = tf.cast(features['train/label'], tf.int64)
+    # Reshape image data into the original shape
+    image = tf.reshape(image, data_shape)
+    return image, label
+
+
+def image_augmentation_parser(serialized_example, feature, data_shape, data_type,
+                              gray_scale=False,
+                              output_shape=None,
+                              random_flip=False,
+                              random_brightness=False,
+                              random_contrast=False,
+                              random_saturation=False,
+                              per_image_normalization=True):
+    """
+    Parses a single tf.Example into augmented image and label tensors.
+    """
+    # Decode the record read by the reader
+    features = tf.parse_single_example(serialized_example, features=feature)
+    # Convert the image data from string back to the numbers
+    image = tf.decode_raw(features['train/data'], data_type)
+
+    # Reshape image data into the original shape
+    image = tf.reshape(image, data_shape)
+
+    # Applying image augmentation
+    image = append_image_augmentation(image, gray_scale=gray_scale,
+                                      output_shape=output_shape,
+                                      random_flip=random_flip,
+                                      random_brightness=random_brightness,
+                                      random_contrast=random_contrast,
+                                      random_saturation=random_saturation,
+                                      per_image_normalization=per_image_normalization)
+
+    # Cast label data into int64
+    label = tf.cast(features['train/label'], tf.int64)
+    return image, label
+
+
+def read_and_decode(filename_queue, data_shape, data_type=tf.float32,
+                    feature=None):
+    """
+    Simplest possible parser for a tfrecord.
+    It assumes that the record contains the pair **train/data** and **train/label**.
+    """
+    if feature is None:
+        feature = DEFAULT_FEATURE
+    # Define a reader and read the next record
+    reader = tf.TFRecordReader()
+    _, serialized_example = reader.read(filename_queue)
+    return example_parser(serialized_example, feature, data_shape, data_type)
+
+
+def create_dataset_from_records(tfrecord_filenames, data_shape, data_type,
+                                feature=None):
+    """
+    Create a dataset from a list of tf-record files
+
+    **Parameters**
+
+    tfrecord_filenames:
+        List containing the tf-record paths
+
+    data_shape:
+        Samples shape saved in the tf-record
+
+    data_type:
+        tf data type (https://www.tensorflow.org/versions/r0.12/resources/dims_types#data_types)
+
+    feature:
+        Feature dictionary passed to ``tf.parse_single_example``; defaults to ``DEFAULT_FEATURE``
+
+    """
+    if feature is None:
+        feature = DEFAULT_FEATURE
+    dataset = tf.contrib.data.TFRecordDataset(tfrecord_filenames)
+    parser = partial(example_parser, feature=feature, data_shape=data_shape,
+                     data_type=data_type)
+    dataset = dataset.map(parser)
+    return dataset
+
+
+def create_dataset_from_records_with_augmentation(tfrecord_filenames, data_shape, data_type,
+                                                  feature=None,
+                                                  gray_scale=False,
+                                                  output_shape=None,
+                                                  random_flip=False,
+                                                  random_brightness=False,
+                                                  random_contrast=False,
+                                                  random_saturation=False,
+                                                  per_image_normalization=True):
+    """
+    Create a dataset from a list of tf-record files, applying image augmentation
+
+    **Parameters**
+
+    tfrecord_filenames:
+        List containing the tf-record paths
+
+    data_shape:
+        Samples shape saved in the tf-record
+
+    data_type:
+        tf data type (https://www.tensorflow.org/versions/r0.12/resources/dims_types#data_types)
+
+    feature:
+        Feature dictionary passed to ``tf.parse_single_example``; defaults to ``DEFAULT_FEATURE``
+
+    """
+    if feature is None:
+        feature = DEFAULT_FEATURE
+    dataset = tf.contrib.data.TFRecordDataset(tfrecord_filenames)
+    parser = partial(image_augmentation_parser, feature=feature, data_shape=data_shape,
+                     data_type=data_type,
+                     gray_scale=gray_scale,
+                     output_shape=output_shape,
+                     random_flip=random_flip,
+                     random_brightness=random_brightness,
+                     random_contrast=random_contrast,
+                     random_saturation=random_saturation,
+                     per_image_normalization=per_image_normalization)
+    dataset = dataset.map(parser)
+    return dataset
+
+
+def shuffle_data_and_labels_image_augmentation(tfrecord_filenames, data_shape, data_type,
+                                               batch_size, epochs=None, buffer_size=10**3,
+                                               gray_scale=False,
+                                               output_shape=None,
+                                               random_flip=False,
+                                               random_brightness=False,
+                                               random_contrast=False,
+                                               random_saturation=False,
+                                               per_image_normalization=True):
+    """
+    Dumps random batches from a list of tf-record files and applies image augmentation
+
+    **Parameters**
+
+    tfrecord_filenames:
+        List containing the tf-record paths
+
+    data_shape:
+        Samples shape saved in the tf-record
+
+    data_type:
+        tf data type (https://www.tensorflow.org/versions/r0.12/resources/dims_types#data_types)
+
+    batch_size:
+        Size of the batch
+
+    epochs:
+        Number of epochs to be batched
+
+    buffer_size:
+        Size of the shuffle buffer
+
+    gray_scale:
+        Convert to gray scale?
+
+    output_shape:
+        If set, crops or pads the image to the given output shape
+
+    random_flip:
+        Randomly flip an image horizontally (https://www.tensorflow.org/api_docs/python/tf/image/random_flip_left_right)
+
+    random_brightness:
+        Adjust the brightness of an RGB image by a random factor (https://www.tensorflow.org/api_docs/python/tf/image/random_brightness)
+
+    random_contrast:
+        Adjust the contrast of an RGB image by a random factor (https://www.tensorflow.org/api_docs/python/tf/image/random_contrast)
+
+    random_saturation:
+        Adjust the saturation of an RGB image by a random factor (https://www.tensorflow.org/api_docs/python/tf/image/random_saturation)
+
+    per_image_normalization:
+        Linearly scales image to have zero mean and unit norm.
+
+    """
+
+    dataset = create_dataset_from_records_with_augmentation(tfrecord_filenames, data_shape,
+                                                            data_type,
+                                                            gray_scale=gray_scale,
+                                                            output_shape=output_shape,
+                                                            random_flip=random_flip,
+                                                            random_brightness=random_brightness,
+                                                            random_contrast=random_contrast,
+                                                            random_saturation=random_saturation,
+                                                            per_image_normalization=per_image_normalization)
+
+    dataset = dataset.shuffle(buffer_size).batch(batch_size).repeat(epochs)
+
+    data, labels = dataset.make_one_shot_iterator().get_next()
+    return data, labels
+
+
+def shuffle_data_and_labels(tfrecord_filenames, data_shape, data_type,
+                            batch_size, epochs=None, buffer_size=10**3):
+    """
+    Dumps random batches from a list of tf-record files
+
+    **Parameters**
+
+    tfrecord_filenames:
+        List containing the tf-record paths
+
+    data_shape:
+        Samples shape saved in the tf-record
+
+    data_type:
+        tf data type (https://www.tensorflow.org/versions/r0.12/resources/dims_types#data_types)
+
+    batch_size:
+        Size of the batch
+
+    epochs:
+        Number of epochs to be batched
+
+    buffer_size:
+        Size of the shuffle buffer
+
+    """
+
+    dataset = create_dataset_from_records(tfrecord_filenames, data_shape,
+                                          data_type)
+    dataset = dataset.shuffle(buffer_size).batch(batch_size).repeat(epochs)
+
+    data, labels = dataset.make_one_shot_iterator().get_next()
+    return data, labels
+
+
+def batch_data_and_labels(tfrecord_filenames, data_shape, data_type,
+                          batch_size, epochs=1):
+    """
+    Dumps in-order batches from a list of tf-record files
+
+    **Parameters**
+
+    tfrecord_filenames:
+        List containing the tf-record paths
+
+    data_shape:
+        Samples shape saved in the tf-record
+
+    data_type:
+        tf data type (https://www.tensorflow.org/versions/r0.12/resources/dims_types#data_types)
+
+    batch_size:
+        Size of the batch
+
+    epochs:
+        Number of epochs to be batched
+
+    """
+    dataset = create_dataset_from_records(tfrecord_filenames, data_shape,
+                                          data_type)
+    dataset = dataset.batch(batch_size).repeat(epochs)
+
+    data, labels = dataset.make_one_shot_iterator().get_next()
+    return data, labels
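For completeness, a usage sketch of the two entry points most callers need (record paths are placeholders; `data_type` must match the dtype the records were written with):

```python
import tensorflow as tf
from bob.learn.tensorflow.dataset.tfrecords import (
    shuffle_data_and_labels, batch_data_and_labels)

train_records = ['./train_mnist.tfrecord']              # placeholder path
validation_records = ['./validation_mnist.tfrecord']    # placeholder path

def input_fn():
    # Shuffled batches for training
    return shuffle_data_and_labels(train_records, data_shape=(28, 28, 1),
                                   data_type=tf.float32, batch_size=16,
                                   epochs=1)

def input_fn_validation():
    # Deterministic, in-order batches for evaluation
    return batch_data_and_labels(validation_records, data_shape=(28, 28, 1),
                                 data_type=tf.float32, batch_size=250,
                                 epochs=1)
```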
diff --git a/bob/learn/tensorflow/test/test_estimator_scripts.py b/bob/learn/tensorflow/test/test_estimator_scripts.py
index 38a8d89031f2da2eef92f3cbc73c835c52f2c083..d45c847fc9a92d5b6c41a11d1ca28d680ead09c3 100644
--- a/bob/learn/tensorflow/test/test_estimator_scripts.py
+++ b/bob/learn/tensorflow/test/test_estimator_scripts.py
@@ -14,7 +14,7 @@ from bob.learn.tensorflow.script.eval_generic import main as eval_generic
 dummy_tfrecord_config = datafile('dummy_verify_config.py', __name__)
 CONFIG = '''
 import tensorflow as tf
-from bob.learn.tensorflow.utils.tfrecords import shuffle_data_and_labels, \
+from bob.learn.tensorflow.dataset.tfrecords import shuffle_data_and_labels, \
     batch_data_and_labels
 
 model_dir = "%(model_dir)s"
diff --git a/bob/learn/tensorflow/test/test_onegraph_estimator.py b/bob/learn/tensorflow/test/test_onegraph_estimator.py
new file mode 100755
index 0000000000000000000000000000000000000000..33cadd805f3b55f9e8866302f74d561913e1b465
--- /dev/null
+++ b/bob/learn/tensorflow/test/test_onegraph_estimator.py
@@ -0,0 +1,187 @@
+#!/usr/bin/env python
+# vim: set fileencoding=utf-8 :
+# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
+
+import tensorflow as tf
+
+from bob.learn.tensorflow.network import dummy
+from bob.learn.tensorflow.trainers import LogitsTrainer, LogitsCenterLossTrainer
+from bob.learn.tensorflow.dataset.tfrecords import shuffle_data_and_labels, \
+    batch_data_and_labels, shuffle_data_and_labels_image_augmentation
+from bob.learn.tensorflow.dataset import append_image_augmentation
+from bob.learn.tensorflow.utils import load_mnist, create_mnist_tfrecord
+from bob.learn.tensorflow.utils.hooks import LoggerHookEstimator
+from bob.learn.tensorflow.utils import reproducible
+from bob.learn.tensorflow.loss import mean_cross_entropy_loss
+
+import numpy
+
+import shutil
+import os
+
+
+tfrecord_train = "./train_mnist.tfrecord"
+tfrecord_validation = "./validation_mnist.tfrecord"
+model_dir = "./temp"
+
+learning_rate = 0.1
+data_shape = (28, 28, 1)  # size of MNIST images
+data_type = tf.float32
+batch_size = 16
+validation_batch_size = 250
+epochs = 1
+steps = 5000
+
+
+def test_logitstrainer():
+    # Trainer logits
+    try:
+        embedding_validation = False
+        trainer = LogitsTrainer(model_dir=model_dir,
+                                architecture=dummy,
+                                optimizer=tf.train.GradientDescentOptimizer(learning_rate),
+                                n_classes=10,
+                                loss_op=mean_cross_entropy_loss,
+                                embedding_validation=embedding_validation,
+                                validation_batch_size=validation_batch_size)
+        run_logitstrainer_mnist(trainer, augmentation=True)
+    finally:
+        try:
+            os.unlink(tfrecord_train)
+            os.unlink(tfrecord_validation)
+            shutil.rmtree(model_dir, ignore_errors=True)
+        except Exception:
+            pass
+
+
+def test_logitstrainer_embedding():
+    try:
+        embedding_validation = True
+        trainer = LogitsTrainer(model_dir=model_dir,
+                                architecture=dummy,
+                                optimizer=tf.train.GradientDescentOptimizer(learning_rate),
+                                n_classes=10,
+                                loss_op=mean_cross_entropy_loss,
+                                embedding_validation=embedding_validation,
+                                validation_batch_size=validation_batch_size)
+        run_logitstrainer_mnist(trainer)
+    finally:
+        try:
+            os.unlink(tfrecord_train)
+            os.unlink(tfrecord_validation)
+            shutil.rmtree(model_dir, ignore_errors=True)
+        except Exception:
+            pass
+
+
+def test_logitstrainer_centerloss():
+
+    try:
+        embedding_validation = False
+        run_config = tf.estimator.RunConfig()
+        run_config = run_config.replace(save_checkpoints_steps=1000)
+        trainer = LogitsCenterLossTrainer(
+            model_dir=model_dir,
+            architecture=dummy,
+            optimizer=tf.train.GradientDescentOptimizer(learning_rate),
+            n_classes=10,
+            embedding_validation=embedding_validation,
+            validation_batch_size=validation_batch_size,
+            factor=0.01,
+            config=run_config)
+
+        run_logitstrainer_mnist(trainer)
+
+        # Checking if the centers were updated
+        sess = tf.Session()
+        checkpoint_path = tf.train.get_checkpoint_state(model_dir).model_checkpoint_path
+        saver = tf.train.import_meta_graph(checkpoint_path + ".meta", clear_devices=True)
+        saver.restore(sess, tf.train.latest_checkpoint(model_dir))
+        centers = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="center_loss/centers:0")[0]
+        assert numpy.sum(numpy.abs(centers.eval(sess))) > 0.0
+
+    finally:
+        try:
+            os.unlink(tfrecord_train)
+            os.unlink(tfrecord_validation)
+            shutil.rmtree(model_dir, ignore_errors=True)
+        except Exception:
+            pass
+
+
+def test_logitstrainer_centerloss_embedding():
+    try:
+        embedding_validation = True
+        trainer = LogitsCenterLossTrainer(
+            model_dir=model_dir,
+            architecture=dummy,
+            optimizer=tf.train.GradientDescentOptimizer(learning_rate),
+            n_classes=10,
+            embedding_validation=embedding_validation,
+            validation_batch_size=validation_batch_size,
+            factor=0.01)
+        run_logitstrainer_mnist(trainer)
+
+        # Checking if the centers were updated
+        sess = tf.Session()
+        checkpoint_path = tf.train.get_checkpoint_state(model_dir).model_checkpoint_path
+        saver = tf.train.import_meta_graph(checkpoint_path + ".meta", clear_devices=True)
+        saver.restore(sess, tf.train.latest_checkpoint(model_dir))
+        centers = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="center_loss/centers:0")[0]
+        assert numpy.sum(numpy.abs(centers.eval(sess))) > 0.0
+    finally:
+        try:
+            os.unlink(tfrecord_train)
+            os.unlink(tfrecord_validation)
+            shutil.rmtree(model_dir, ignore_errors=True)
+        except Exception:
+            pass
+
+
+def run_logitstrainer_mnist(trainer, augmentation=False):
+
+    # Cleaning up
+    tf.reset_default_graph()
+    assert len(tf.global_variables()) == 0
+
+    # Creating tf records for mnist
+    train_data, train_labels, validation_data, validation_labels = load_mnist()
+    create_mnist_tfrecord(tfrecord_train, train_data, train_labels, n_samples=6000)
+    create_mnist_tfrecord(tfrecord_validation, validation_data, validation_labels,
+                          n_samples=validation_batch_size)
+
+    def input_fn():
+        if augmentation:
+            return shuffle_data_and_labels_image_augmentation(tfrecord_train, data_shape,
+                                                              data_type, batch_size,
+                                                              epochs=epochs)
+        else:
+            return shuffle_data_and_labels(tfrecord_train, data_shape, data_type,
+                                           batch_size, epochs=epochs)
+
+    def input_fn_validation():
+        return batch_data_and_labels(tfrecord_validation, data_shape, data_type,
+                                     validation_batch_size, epochs=1000)
+
+    hooks = [LoggerHookEstimator(trainer, 16, 300),
+             tf.train.SummarySaverHook(save_steps=1000,
+                                       output_dir=model_dir,
+                                       scaffold=tf.train.Scaffold(),
+                                       summary_writer=tf.summary.FileWriter(model_dir))]
+
+    trainer.train(input_fn, steps=steps, hooks=hooks)
+
+    if not trainer.embedding_validation:
+        acc = trainer.evaluate(input_fn_validation)
+        assert acc['accuracy'] > 0.80
+    else:
+        acc = trainer.evaluate(input_fn_validation)
+        assert acc['accuracy'] > 0.80
+
+    # Cleaning up
+    tf.reset_default_graph()
+    assert len(tf.global_variables()) == 0
diff --git a/bob/learn/tensorflow/test/test_onegraph_model_fn.py b/bob/learn/tensorflow/test/test_onegraph_model_fn.py
deleted file mode 100755
index 2b73e88a5035bc784777ac7787b2240ddf4387e0..0000000000000000000000000000000000000000
--- a/bob/learn/tensorflow/test/test_onegraph_model_fn.py
+++ /dev/null
@@ -1,162 +0,0 @@
-#!/usr/bin/env python
-# vim: set fileencoding=utf-8 :
-# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
-
-import tensorflow as tf
-
-from bob.learn.tensorflow.network import dummy
-from bob.learn.tensorflow.trainers import LogitsTrainer, LogitsCenterLossTrainer
-from bob.learn.tensorflow.utils.tfrecords import shuffle_data_and_labels, batch_data_and_labels
-from bob.learn.tensorflow.utils import load_mnist, create_mnist_tfrecord
-from bob.learn.tensorflow.utils.hooks import LoggerHookEstimator
-from bob.learn.tensorflow.loss import mean_cross_entropy_loss
-import numpy
-
-import shutil
-import os
-
-
-tfrecord_train = "./train_mnist.tfrecord"
-tfrecord_validation = "./validation_mnist.tfrecord"
-model_dir = "./temp"
-
-learning_rate = 0.1
-data_shape = (28, 28, 1)  # size of atnt images
-data_type = tf.float32
-batch_size = 16
-validation_batch_size = 1000
-epochs = 1
-steps = 2000
-
-
-def test_logitstrainer():
-    run_logitstrainer(False)
-
-
-def test_logitstrainer_embedding():
-    run_logitstrainer(True)
-
-
-def test_logitstrainer_centerloss():
-    run_logitstrainer_centerloss(False)
-
-
-def test_logitstrainer_centerloss_embedding():
-    run_logitstrainer_centerloss(True)
-
-
-def run_logitstrainer(embedding_validation):
-
-    # Cleaning up
-    tf.reset_default_graph()
-    assert len(tf.global_variables()) == 0
-
-    # Creating tf records for mnist
-    train_data, train_labels, validation_data, validation_labels = load_mnist()
-    create_mnist_tfrecord(tfrecord_train, train_data, train_labels, n_samples=6000)
-    create_mnist_tfrecord(tfrecord_validation, validation_data, validation_labels, n_samples=1000)
-
-    try:
-
-        # Trainer logits
-        trainer = LogitsTrainer(model_dir=model_dir,
-                                architecture=dummy,
-                                optimizer=tf.train.GradientDescentOptimizer(learning_rate),
-                                n_classes=10,
-                                loss_op=mean_cross_entropy_loss,
-                                embedding_validation=embedding_validation,
-                                validation_batch_size=validation_batch_size
-                                )
-
-        def input_fn():
-            return shuffle_data_and_labels(tfrecord_train, data_shape, data_type,
-                                           batch_size, epochs=epochs)
-
-        def input_fn_validation():
-            return batch_data_and_labels(tfrecord_validation, data_shape, data_type,
-                                         validation_batch_size, epochs=epochs)
-
-        hooks = [LoggerHookEstimator(trainer, 16, 100)]
-        trainer.train(input_fn, steps=steps, hooks=hooks)
-
-        if not embedding_validation:
-
-            acc = trainer.evaluate(input_fn_validation)
-            assert acc['accuracy'] > 0.80
-        else:
-            acc = trainer.evaluate(input_fn_validation)
-            assert acc['accuracy'] > 0.80
-
-    finally:
-        try:
-            os.unlink(tfrecord_train)
-            os.unlink(tfrecord_validation)
-            shutil.rmtree(model_dir)
-        except Exception:
-            pass
-
-    # Cleaning up
-    tf.reset_default_graph()
-    assert len(tf.global_variables()) == 0
-
-
-def run_logitstrainer_centerloss(embedding_validation):
-
-    # Cleaning up
-    tf.reset_default_graph()
-    assert len(tf.global_variables()) == 0
-
-    # Creating tf records for mnist
-    train_data, train_labels, validation_data, validation_labels = load_mnist()
-    create_mnist_tfrecord(tfrecord_train, train_data, train_labels, n_samples=6000)
-    create_mnist_tfrecord(tfrecord_validation, validation_data, validation_labels, n_samples=1000)
-
-    try:
-
-        # Trainer logits
-        trainer = LogitsCenterLossTrainer(
-            model_dir=model_dir,
-            architecture=dummy,
-            optimizer=tf.train.GradientDescentOptimizer(learning_rate),
-            n_classes=10,
-            embedding_validation=embedding_validation,
-            validation_batch_size=validation_batch_size,
-            factor=0.01
-        )
-
-        def input_fn():
-            return shuffle_data_and_labels(tfrecord_train, data_shape, data_type,
-                                           batch_size, epochs=epochs)
-
-        def input_fn_validation():
-            return batch_data_and_labels(tfrecord_validation, data_shape, data_type,
-                                         validation_batch_size, epochs=epochs)
-
-        hooks = [LoggerHookEstimator(trainer, 16, 100)]
-        trainer.train(input_fn, steps=steps, hooks=hooks)
-
-        if not embedding_validation:
-            acc = trainer.evaluate(input_fn_validation)
-            assert acc['accuracy'] > 0.80
-        else:
-            acc = trainer.evaluate(input_fn_validation)
-            assert acc['accuracy'] > 0.80
-
-        sess = tf.Session()
-        checkpoint_path = tf.train.get_checkpoint_state(model_dir).model_checkpoint_path
-        saver = tf.train.import_meta_graph(checkpoint_path + ".meta", clear_devices=True)
-        saver.restore(sess, tf.train.latest_checkpoint(model_dir))
-        centers = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="center_loss/centers:0")[0]
-        assert numpy.sum(numpy.abs(centers.eval(sess))) > 0.0
-
-    finally:
-        try:
-            os.unlink(tfrecord_train)
-            os.unlink(tfrecord_validation)
-            shutil.rmtree(model_dir)
-        except Exception:
-            pass
-
-    # Cleaning up
-    tf.reset_default_graph()
-    assert len(tf.global_variables()) == 0
diff --git a/bob/learn/tensorflow/utils/tfrecords.py b/bob/learn/tensorflow/utils/tfrecords.py
deleted file mode 100644
index 48da0740577c2a64e4e6f59b26dac959e0a0678f..0000000000000000000000000000000000000000
--- a/bob/learn/tensorflow/utils/tfrecords.py
+++ /dev/null
@@ -1,60 +0,0 @@
-from functools import partial
-import tensorflow as tf
-
-
-DEFAULT_FEATURE = {'train/data': tf.FixedLenFeature([], tf.string),
-                   'train/label': tf.FixedLenFeature([], tf.int64)}
-
-
-def example_parser(serialized_example, feature, data_shape, data_type):
-    """Parses a single tf.Example into image and label tensors."""
-    # Decode the record read by the reader
-    features = tf.parse_single_example(serialized_example, features=feature)
-    # Convert the image data from string back to the numbers
-    image = tf.decode_raw(features['train/data'], data_type)
-    # Cast label data into int64
-    label = tf.cast(features['train/label'], tf.int64)
-    # Reshape image data into the original shape
-    image = tf.reshape(image, data_shape)
-    return image, label
-
-
-def read_and_decode(filename_queue, data_shape, data_type=tf.float32,
-                    feature=None):
-    if feature is None:
-        feature = DEFAULT_FEATURE
-    # Define a reader and read the next record
-    reader = tf.TFRecordReader()
-    _, serialized_example = reader.read(filename_queue)
-    return example_parser(serialized_example, feature, data_shape, data_type)
-
-
-def create_dataset_from_records(tfrecord_filenames, data_shape, data_type,
-                                feature=None):
-    if feature is None:
-        feature = DEFAULT_FEATURE
-    dataset = tf.contrib.data.TFRecordDataset(tfrecord_filenames)
-    parser = partial(example_parser, feature=feature, data_shape=data_shape,
-                     data_type=data_type)
-    dataset = dataset.map(parser)
-    return dataset
-
-
-def shuffle_data_and_labels(tfrecord_filenames, data_shape, data_type,
-                            batch_size, epochs=None, buffer_size=10**3):
-    dataset = create_dataset_from_records(tfrecord_filenames, data_shape,
-                                          data_type)
-    dataset = dataset.shuffle(buffer_size).batch(batch_size).repeat(epochs)
-
-    datas, labels = dataset.make_one_shot_iterator().get_next()
-    return datas, labels
-
-
-def batch_data_and_labels(tfrecord_filenames, data_shape, data_type,
-                          batch_size, epochs=1):
-    dataset = create_dataset_from_records(tfrecord_filenames, data_shape,
-                                          data_type)
-    dataset = dataset.batch(batch_size).repeat(epochs)
-
-    datas, labels = dataset.make_one_shot_iterator().get_next()
-    return datas, labels
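Since `bob.learn.tensorflow.utils.tfrecords` is removed, downstream code only needs its import path updated; the test change above shows the same one-line migration:

```python
# Before (module removed by this patch):
#   from bob.learn.tensorflow.utils.tfrecords import shuffle_data_and_labels
# After:
from bob.learn.tensorflow.dataset.tfrecords import shuffle_data_and_labels
```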