diff --git a/bob/learn/tensorflow/dataset/__init__.py b/bob/learn/tensorflow/dataset/__init__.py index 7c0af9b4e73d12e1372baa2a8bc03b83ba6f608a..a742b1e092572a4ab5ce5158d957bb96109645d4 100644 --- a/bob/learn/tensorflow/dataset/__init__.py +++ b/bob/learn/tensorflow/dataset/__init__.py @@ -78,6 +78,8 @@ def append_image_augmentation(image, # Casting to float32 image = tf.cast(image, tf.float32) + # FORCING A SEED FOR THE RANDOM OPERATIONS + tf.set_random_seed(0) if output_shape is not None: assert len(output_shape) == 2 diff --git a/bob/learn/tensorflow/dataset/image.py b/bob/learn/tensorflow/dataset/image.py index e85f2a12a9f7190111dbd6c115a5f40cfb9ec2a6..c5888b8da966165e486e62251c975e3452160f81 100644 --- a/bob/learn/tensorflow/dataset/image.py +++ b/bob/learn/tensorflow/dataset/image.py @@ -25,7 +25,7 @@ def shuffle_data_and_labels_image_augmentation(filenames, extension=None): """ Dump random batches from a list of image paths and labels: - + The list of files and labels should be in the same order e.g. filenames = ['class_1_img1', 'class_1_img2', 'class_2_img1'] labels = [0, 0, 1] @@ -34,28 +34,28 @@ def shuffle_data_and_labels_image_augmentation(filenames, filenames: List containing the path of the images - + labels: List containing the labels (needs to be in EXACT same order as filenames) - + data_shape: Samples shape saved in the tf-record - + data_type: tf data type(https://www.tensorflow.org/versions/r0.12/resources/dims_types#data_types) - + batch_size: Size of the batch - + epochs: Number of epochs to be batched - + buffer_size: Size of the shuffle bucket gray_scale: Convert to gray scale? - + output_shape: If set, will randomly crop the image given the output shape @@ -79,7 +79,7 @@ def shuffle_data_and_labels_image_augmentation(filenames, extension: If None, will load files using `tf.image.decode..` if set to `hdf5`, will load with `bob.io.base.load` - + """ dataset = create_dataset_from_path_augmentation( @@ -118,23 +118,23 @@ def create_dataset_from_path_augmentation(filenames, extension=None): """ Create dataset from a list of tf-record files - + **Parameters** - + filenames: List containing the path of the images - + labels: List containing the labels (needs to be in EXACT same order as filenames) - + data_shape: Samples shape saved in the tf-record - + data_type: tf data type(https://www.tensorflow.org/versions/r0.12/resources/dims_types#data_types) - + feature: - + """ parser = partial( @@ -151,7 +151,7 @@ def create_dataset_from_path_augmentation(filenames, per_image_normalization=per_image_normalization, extension=extension) - dataset = tf.contrib.data.Dataset.from_tensor_slices((filenames, labels)) + dataset = tf.data.Dataset.from_tensor_slices((filenames, labels)) dataset = dataset.map(parser) return dataset diff --git a/bob/learn/tensorflow/dataset/siamese_image.py b/bob/learn/tensorflow/dataset/siamese_image.py index 1c8ca7a0f75a7b1a3b4844e1d1ffba6d00ef29c4..185ac696d4ee0f5ecfed0b524febbb057b93f54c 100644 --- a/bob/learn/tensorflow/dataset/siamese_image.py +++ b/bob/learn/tensorflow/dataset/siamese_image.py @@ -25,43 +25,43 @@ def shuffle_data_and_labels_image_augmentation(filenames, extension=None): """ Dump random batches for siamese networks from a list of image paths and labels: - + The list of files and labels should be in the same order e.g. 
filenames = ['class_1_img1', 'class_1_img2', 'class_2_img1'] labels = [0, 0, 1] - + The batches returned with tf.Session.run() with be in the following format: - **data** a dictionary containing the keys ['left', 'right'], each one representing + **data** a dictionary containing the keys ['left', 'right'], each one representing one element of the pair and **labels** which is [0, 1] where 0 is the genuine pair and 1 is the impostor pair. - + **Parameters** filenames: List containing the path of the images - + labels: List containing the labels (needs to be in EXACT same order as filenames) - + data_shape: Samples shape saved in the tf-record - + data_type: tf data type(https://www.tensorflow.org/versions/r0.12/resources/dims_types#data_types) - + batch_size: Size of the batch - + epochs: Number of epochs to be batched - + buffer_size: Size of the shuffle bucket gray_scale: Convert to gray scale? - + output_shape: If set, will randomly crop the image given the output shape @@ -79,10 +79,10 @@ def shuffle_data_and_labels_image_augmentation(filenames, random_rotate: Randomly rotate face images between -5 and 5 degrees - + per_image_normalization: - Linearly scales image to have zero mean and unit norm. - + Linearly scales image to have zero mean and unit norm. + extension: If None, will load files using `tf.image.decode..` if set to `hdf5`, will load with `bob.io.base.load` """ @@ -122,33 +122,33 @@ def create_dataset_from_path_augmentation(filenames, extension=None): """ Create dataset from a list of tf-record files - + **Parameters** - + filenames: List containing the path of the images - + labels: List containing the labels (needs to be in EXACT same order as filenames) - + data_shape: Samples shape saved in the tf-record - + data_type: tf data type(https://www.tensorflow.org/versions/r0.12/resources/dims_types#data_types) - + batch_size: Size of the batch - + epochs: Number of epochs to be batched - + buffer_size: Size of the shuffle bucket gray_scale: Convert to gray scale? - + output_shape: If set, will randomly crop the image given the output shape @@ -168,11 +168,11 @@ def create_dataset_from_path_augmentation(filenames, Randomly rotate face images between -10 and 10 degrees per_image_normalization: - Linearly scales image to have zero mean and unit norm. - + Linearly scales image to have zero mean and unit norm. + extension: If None, will load files using `tf.image.decode..` if set to `hdf5`, will load with `bob.io.base.load` - + """ parser = partial( @@ -191,7 +191,7 @@ def create_dataset_from_path_augmentation(filenames, left_data, right_data, siamese_labels = siamease_pairs_generator( filenames, labels) - dataset = tf.contrib.data.Dataset.from_tensor_slices( + dataset = tf.data.Dataset.from_tensor_slices( (left_data, right_data, siamese_labels)) dataset = dataset.map(parser) return dataset diff --git a/bob/learn/tensorflow/dataset/triplet_image.py b/bob/learn/tensorflow/dataset/triplet_image.py index 5d56b75e84d3c188e651115a5c061831c76c3832..bfd10632216aeb50cbf78a9c1187d9cf734fb510 100644 --- a/bob/learn/tensorflow/dataset/triplet_image.py +++ b/bob/learn/tensorflow/dataset/triplet_image.py @@ -25,41 +25,41 @@ def shuffle_data_and_labels_image_augmentation(filenames, extension=None): """ Dump random batches for triplee networks from a list of image paths and labels: - + The list of files and labels should be in the same order e.g. 
filenames = ['class_1_img1', 'class_1_img2', 'class_2_img1'] labels = [0, 0, 1] - + The batches returned with tf.Session.run() with be in the following format: **data** a dictionary containing the keys ['anchor', 'positive', 'negative']. - + **Parameters** filenames: List containing the path of the images - + labels: List containing the labels (needs to be in EXACT same order as filenames) - + data_shape: Samples shape saved in the tf-record - + data_type: tf data type(https://www.tensorflow.org/versions/r0.12/resources/dims_types#data_types) - + batch_size: Size of the batch - + epochs: Number of epochs to be batched - + buffer_size: Size of the shuffle bucket gray_scale: Convert to gray scale? - + output_shape: If set, will randomly crop the image given the output shape @@ -80,10 +80,10 @@ def shuffle_data_and_labels_image_augmentation(filenames, per_image_normalization: Linearly scales image to have zero mean and unit norm. - + extension: If None, will load files using `tf.image.decode..` if set to `hdf5`, will load with `bob.io.base.load` - + """ dataset = create_dataset_from_path_augmentation( @@ -123,23 +123,23 @@ def create_dataset_from_path_augmentation(filenames, extension=None): """ Create dataset from a list of tf-record files - + **Parameters** - + filenames: List containing the path of the images - + labels: List containing the labels (needs to be in EXACT same order as filenames) - + data_shape: Samples shape saved in the tf-record - + data_type: tf data type(https://www.tensorflow.org/versions/r0.12/resources/dims_types#data_types) - + feature: - + """ parser = partial( @@ -159,7 +159,7 @@ def create_dataset_from_path_augmentation(filenames, anchor_data, positive_data, negative_data = triplets_random_generator( filenames, labels) - dataset = tf.contrib.data.Dataset.from_tensor_slices( + dataset = tf.data.Dataset.from_tensor_slices( (anchor_data, positive_data, negative_data)) dataset = dataset.map(parser) return dataset diff --git a/bob/learn/tensorflow/estimators/Logits.py b/bob/learn/tensorflow/estimators/Logits.py index dcb24440793110307107210d77c46861b3442fa2..499e431160b5b167c5594f9dd8a85add614b87a9 100755 --- a/bob/learn/tensorflow/estimators/Logits.py +++ b/bob/learn/tensorflow/estimators/Logits.py @@ -99,7 +99,8 @@ class Logits(estimator.Estimator): validation_batch_size=None, params=None, extra_checkpoint=None, - apply_moving_averages=True): + apply_moving_averages=True, + add_histograms=None): self.architecture = architecture self.optimizer = optimizer @@ -114,64 +115,17 @@ class Logits(estimator.Estimator): check_features(features) data = features['data'] key = features['key'] - # Configure the Training Op (for TRAIN mode) - if mode == tf.estimator.ModeKeys.TRAIN: - - # Building the training graph - - # Checking if we have some variables/scope that we may want to shut down - trainable_variables = get_trainable_variables( - self.extra_checkpoint) - prelogits = self.architecture( - data, mode=mode, - trainable_variables=trainable_variables)[0] - logits = append_logits(prelogits, n_classes) - - if self.extra_checkpoint is not None: - tf.contrib.framework.init_from_checkpoint( - self.extra_checkpoint["checkpoint_path"], - self.extra_checkpoint["scopes"]) - - global_step = tf.train.get_or_create_global_step() - - # Compute the moving average of all individual losses and the total loss. 
- if apply_moving_averages: - variable_averages = tf.train.ExponentialMovingAverage( - 0.9999, global_step) - variable_averages_op = variable_averages.apply( - tf.trainable_variables()) - else: - variable_averages_op = tf.no_op(name='noop') - - with tf.control_dependencies([variable_averages_op]): - - # Compute Loss (for both TRAIN and EVAL modes) - self.loss = self.loss_op(logits=logits, labels=labels) - - # Compute the moving average of all individual losses and the total loss. - loss_averages = tf.train.ExponentialMovingAverage( - 0.9, name='avg') - loss_averages_op = loss_averages.apply( - tf.get_collection(tf.GraphKeys.LOSSES)) - - for l in tf.get_collection(tf.GraphKeys.LOSSES): - tf.summary.scalar(l.op.name + "_averaged", - loss_averages.average(l)) - - global_step = tf.train.get_or_create_global_step() - train_op = tf.group( - self.optimizer.minimize( - self.loss, global_step=global_step), - variable_averages_op, loss_averages_op) - - return tf.estimator.EstimatorSpec( - mode=mode, loss=self.loss, train_op=train_op) - # Building the training graph for PREDICTION OR VALIDATION - prelogits = self.architecture(data, mode=mode)[0] + # Checking if we have some variables/scope that we may want to shut + # down + trainable_variables = get_trainable_variables( + self.extra_checkpoint, mode=mode) + prelogits = self.architecture( + data, mode=mode, trainable_variables=trainable_variables)[0] logits = append_logits(prelogits, n_classes) - if self.embedding_validation: + if self.embedding_validation and mode != tf.estimator.ModeKeys.TRAIN: + # Compute the embeddings embeddings = tf.nn.l2_normalize(prelogits, 1) predictions = { @@ -179,13 +133,13 @@ class Logits(estimator.Estimator): "key": key, } else: - probabilities = tf.nn.softmax(logits, name="softmax_tensor") predictions = { # Generate predictions (for PREDICT and EVAL mode) "classes": tf.argmax(input=logits, axis=1), # Add `softmax_tensor` to the graph. It is used for PREDICT # and by the `logging_hook`. 
- "probabilities": probabilities, + "probabilities": tf.nn.softmax( + logits, name="softmax_tensor"), "key": key, } @@ -193,31 +147,88 @@ class Logits(estimator.Estimator): return tf.estimator.EstimatorSpec( mode=mode, predictions=predictions) - # IF Validation - self.loss = self.loss_op(logits=logits, labels=labels) - - if self.embedding_validation: + if self.embedding_validation and mode != tf.estimator.ModeKeys.TRAIN: predictions_op = predict_using_tensors( predictions["embeddings"], labels, num=validation_batch_size) - eval_metric_ops = { - "accuracy": - tf.metrics.accuracy( - labels=labels, predictions=predictions_op) - } - return tf.estimator.EstimatorSpec( - mode=mode, loss=self.loss, eval_metric_ops=eval_metric_ops) - else: - # Add evaluation metrics (for EVAL mode) - eval_metric_ops = { - "accuracy": - tf.metrics.accuracy( - labels=labels, predictions=predictions["classes"]) - } + predictions_op = predictions["classes"] + + accuracy = tf.metrics.accuracy( + labels=labels, predictions=predictions_op) + metrics = {'accuracy': accuracy} + + if mode == tf.estimator.ModeKeys.EVAL: + self.loss = self.loss_op(logits=logits, labels=labels) return tf.estimator.EstimatorSpec( - mode=mode, loss=self.loss, eval_metric_ops=eval_metric_ops) + mode=mode, + predictions=predictions, + loss=self.loss, + train_op=None, + eval_metric_ops=metrics) + + # restore the model from an extra_checkpoint + if extra_checkpoint is not None: + tf.train.init_from_checkpoint( + ckpt_dir_or_file=extra_checkpoint["checkpoint_path"], + assignment_map=extra_checkpoint["scopes"], + ) + + global_step = tf.train.get_or_create_global_step() + + # Compute the moving average of all individual losses and the + # total loss. + if apply_moving_averages: + variable_averages = tf.train.ExponentialMovingAverage( + 0.9999, global_step) + variable_averages_op = variable_averages.apply( + tf.trainable_variables()) + else: + variable_averages_op = tf.no_op(name='noop') + + # Some layer like tf.layers.batch_norm need this: + update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) + + with tf.control_dependencies([variable_averages_op] + update_ops): + + # Calculate Loss + self.loss = self.loss_op(logits=logits, labels=labels) + + # Compute the moving average of all individual losses + # and the total loss. 
+ loss_averages = tf.train.ExponentialMovingAverage( + 0.9, name='avg') + loss_averages_op = loss_averages.apply( + tf.get_collection(tf.GraphKeys.LOSSES)) + + train_op = tf.group( + self.optimizer.minimize( + self.loss, global_step=global_step), + variable_averages_op, loss_averages_op) + + # Log accuracy and loss + with tf.name_scope('train_metrics'): + tf.summary.scalar('accuracy', accuracy[1]) + tf.summary.scalar('loss', self.loss) + for l in tf.get_collection(tf.GraphKeys.LOSSES): + tf.summary.scalar(l.op.name + "_averaged", + loss_averages.average(l)) + + # add histograms summaries + if add_histograms == 'all': + for v in tf.all_variables(): + tf.summary.histogram(v.name, v) + elif add_histograms == 'train': + for v in tf.trainable_variables(): + tf.summary.histogram(v.name, v) + + return tf.estimator.EstimatorSpec( + mode=mode, + predictions=predictions, + loss=self.loss, + train_op=train_op, + eval_metric_ops=metrics) super(Logits, self).__init__( model_fn=_model_fn, diff --git a/bob/learn/tensorflow/estimators/__init__.py b/bob/learn/tensorflow/estimators/__init__.py index 59f0b905c79a50f3d81bf5fabb88551043a15a39..1f9c1d64b3d0abd758d10ba3fe2cb8fa515b73a4 100644 --- a/bob/learn/tensorflow/estimators/__init__.py +++ b/bob/learn/tensorflow/estimators/__init__.py @@ -13,7 +13,8 @@ def check_features(features): return True -def get_trainable_variables(extra_checkpoint): +def get_trainable_variables(extra_checkpoint, + mode=tf.estimator.ModeKeys.TRAIN): """ Given the extra_checkpoint dictionary provided to the estimator, extract the content of "trainable_variables" e. @@ -24,15 +25,20 @@ def get_trainable_variables(extra_checkpoint): Parameters ---------- - extra_checkpoint: dict - The `extra_checkpoint dictionary provided to the estimator + extra_checkpoint : dict + The `extra_checkpoint dictionary provided to the estimator + mode + The estimator mode. TRAIN, EVAL, and PREDICT. If not TRAIN, None is + returned. 
Returns ------- - Returns `None` if `trainable_variables` is not in extra_checkpoint; - otherwise returns the content of `extra_checkpoint + Returns `None` if `trainable_variables` is not in extra_checkpoint; + otherwise returns the content of `extra_checkpoint """ + if mode != tf.estimator.ModeKeys.TRAIN: + return None # If you don't set anything, everything is trainable if extra_checkpoint is None or "trainable_variables" not in extra_checkpoint: diff --git a/bob/learn/tensorflow/examples/mnist/mnist_config.py b/bob/learn/tensorflow/examples/mnist/mnist_config.py index b6780ff74d73fce28e387161de0b11413c6a798d..38e8da80cd299e97e80acb1843462e0a1fa3b120 100644 --- a/bob/learn/tensorflow/examples/mnist/mnist_config.py +++ b/bob/learn/tensorflow/examples/mnist/mnist_config.py @@ -75,7 +75,7 @@ def input_fn(mode, batch_size=1): # Map example_parser over dataset, and batch results by up to batch_size dataset = dataset.map( - example_parser, num_threads=1, output_buffer_size=batch_size) + example_parser, num_parallel_calls=1).prefetch(batch_size) dataset = dataset.batch(batch_size) images, labels, keys = dataset.make_one_shot_iterator().get_next() diff --git a/bob/learn/tensorflow/network/JointIncResV2Simple.py b/bob/learn/tensorflow/network/JointIncResV2Simple.py index 3044de704ccc6740741bc6101f5c88adea41f668..5ea81dc23c91b5134e7b0336ba23b99bb3e1d01b 100644 --- a/bob/learn/tensorflow/network/JointIncResV2Simple.py +++ b/bob/learn/tensorflow/network/JointIncResV2Simple.py @@ -4,13 +4,7 @@ import numpy as np import tensorflow as tf -def model_fn(features, labels, mode, params, config): - """The model function for join face and patch PAD. The input to the model - is 160x160 faces.""" - - faces = features['data'] - key = features['key'] - +def architecture(faces, mode, **kwargs): # construct patches inside the model ksizes = strides = [1, 28, 28, 1] rates = [1, 1, 1, 1] @@ -19,18 +13,12 @@ def model_fn(features, labels, mode, params, config): # n_blocks should be 25 for 160x160 faces patches = tf.reshape(patches, [-1, n_blocks, 28, 28, 3]) - # organize the parameters - params = params or {} - learning_rate = params.get('learning_rate', 1e-4) - apply_moving_averages = params.get('apply_moving_averages', True) - n_classes = params.get('n_classes', 2) - add_histograms = params.get('add_histograms') - simplecnn_kwargs = { 'kernerl_size': (3, 3), 'data_format': 'channels_last', 'add_batch_norm': True, } + endpoints = {} # construct simplecnn from patches for i in range(n_blocks): if i == 0: @@ -38,19 +26,43 @@ def model_fn(features, labels, mode, params, config): else: reuse = True with tf.variable_scope('SimpleCNN', reuse=reuse): - net, _ = simplecnn_arch(patches[:, i], mode, **simplecnn_kwargs) + net, temp = simplecnn_arch(patches[:, i], mode, **simplecnn_kwargs) if i == 0: simplecnn_embeddings = net + endpoints.update(temp) else: simplecnn_embeddings += net # average the embeddings of patches simplecnn_embeddings /= n_blocks # construct inception_resnet_v2 from faces - incresv2_embeddings, _ = inception_resnet_v2_batch_norm(faces, mode=mode) + incresv2_embeddings, temp = inception_resnet_v2_batch_norm( + faces, mode=mode) + endpoints.update(temp) embeddings = tf.concat([simplecnn_embeddings, incresv2_embeddings], 1) + endpoints['final_embeddings'] = embeddings + + return embeddings, endpoints + + +def model_fn(features, labels, mode, params, config): + """The model function for join face and patch PAD. 
The input to the model + is 160x160 faces.""" + + faces = features['data'] + key = features['key'] + + # organize the parameters + params = params or {} + learning_rate = params.get('learning_rate', 1e-4) + apply_moving_averages = params.get('apply_moving_averages', True) + n_classes = params.get('n_classes', 2) + add_histograms = params.get('add_histograms') + + embeddings, _ = architecture(faces, mode) + # Logits layer logits = tf.layers.dense(inputs=embeddings, units=n_classes, name='logits') diff --git a/bob/learn/tensorflow/network/SimpleCNN.py b/bob/learn/tensorflow/network/SimpleCNN.py index c44ccfbdec762ca9beaba493b6800d8ed71e9a72..2a0e15d81ee3ee0968bbc4f3bfd91c96fa62be03 100644 --- a/bob/learn/tensorflow/network/SimpleCNN.py +++ b/bob/learn/tensorflow/network/SimpleCNN.py @@ -184,10 +184,11 @@ def model_fn(features, labels, mode, params=None, config=None): params = params or {} learning_rate = params.get('learning_rate', 1e-5) apply_moving_averages = params.get('apply_moving_averages', False) - extra_checkpoint = params.get('extra_checkpoint', None) + extra_checkpoint = params.get('extra_checkpoint') trainable_variables = get_trainable_variables(extra_checkpoint) loss_weights = params.get('loss_weights', 1.0) - add_histograms = params.get('add_histograms', None) + add_histograms = params.get('add_histograms') + nnet_optimizer = params.get('nnet_optimizer') or 'sgd' arch_kwargs = { 'kernerl_size': params.get('kernerl_size', None), @@ -260,8 +261,12 @@ def model_fn(features, labels, mode, params=None, config=None): if mode == tf.estimator.ModeKeys.TRAIN: - optimizer = tf.train.GradientDescentOptimizer( - learning_rate=learning_rate) + if nnet_optimizer == 'sgd': + optimizer = tf.train.GradientDescentOptimizer( + learning_rate=learning_rate) + else: + optimizer = tf.train.AdamOptimizer( + learning_rate=learning_rate) train_op = tf.group( optimizer.minimize(loss, global_step=global_step), variable_averages_op, loss_averages_op) diff --git a/bob/learn/tensorflow/test/test_estimator_onegraph.py b/bob/learn/tensorflow/test/test_estimator_onegraph.py index 5cc989d931b1c5fa29a5cf82cddb0a17556dda47..5b848ba138013292b7ca139be171e3260502fa6c 100644 --- a/bob/learn/tensorflow/test/test_estimator_onegraph.py +++ b/bob/learn/tensorflow/test/test_estimator_onegraph.py @@ -38,6 +38,7 @@ def test_logitstrainer(): # Trainer logits try: embedding_validation = False + _, run_config,_,_,_ = reproducible.set_seed() trainer = Logits( model_dir=model_dir, architecture=dummy, @@ -45,7 +46,8 @@ def test_logitstrainer(): n_classes=10, loss_op=mean_cross_entropy_loss, embedding_validation=embedding_validation, - validation_batch_size=validation_batch_size) + validation_batch_size=validation_batch_size, + config=run_config) run_logitstrainer_mnist(trainer, augmentation=True) finally: try: @@ -59,6 +61,7 @@ def test_logitstrainer(): def test_logitstrainer_embedding(): try: embedding_validation = True + _, run_config,_,_,_ = reproducible.set_seed() trainer = Logits( model_dir=model_dir, architecture=dummy, @@ -66,8 +69,9 @@ def test_logitstrainer_embedding(): n_classes=10, loss_op=mean_cross_entropy_loss, embedding_validation=embedding_validation, - validation_batch_size=validation_batch_size) - + validation_batch_size=validation_batch_size, + config=run_config) + run_logitstrainer_mnist(trainer) finally: try: @@ -81,7 +85,7 @@ def test_logitstrainer_embedding(): def test_logitstrainer_centerloss(): try: embedding_validation = False - run_config = tf.estimator.RunConfig() + _, run_config,_,_,_ = 
reproducible.set_seed() run_config = run_config.replace(save_checkpoints_steps=1000) trainer = LogitsCenterLoss( model_dir=model_dir, @@ -118,6 +122,7 @@ def test_logitstrainer_centerloss(): def test_logitstrainer_centerloss_embedding(): try: embedding_validation = True + _, run_config,_,_,_ = reproducible.set_seed() trainer = LogitsCenterLoss( model_dir=model_dir, architecture=dummy, @@ -125,7 +130,9 @@ def test_logitstrainer_centerloss_embedding(): n_classes=10, embedding_validation=embedding_validation, validation_batch_size=validation_batch_size, - factor=0.01) + factor=0.01, + config=run_config + ) run_logitstrainer_mnist(trainer) # Checking if the centers were updated @@ -170,7 +177,7 @@ def run_logitstrainer_mnist(trainer, augmentation=False): data_type, batch_size, random_flip=True, - random_rotate=True, + random_rotate=False, epochs=epochs) else: return shuffle_data_and_labels( @@ -196,7 +203,6 @@ def run_logitstrainer_mnist(trainer, augmentation=False): scaffold=tf.train.Scaffold(), summary_writer=tf.summary.FileWriter(model_dir)) ] - trainer.train(input_fn, steps=steps, hooks=hooks) if not trainer.embedding_validation: acc = trainer.evaluate(input_fn_validation) diff --git a/bob/learn/tensorflow/test/test_estimator_transfer.py b/bob/learn/tensorflow/test/test_estimator_transfer.py index 3f33bdb469978007b2190c6ea25f958569831dbe..576fb4e3638a558d644e6a192c0a3550c26610e3 100644 --- a/bob/learn/tensorflow/test/test_estimator_transfer.py +++ b/bob/learn/tensorflow/test/test_estimator_transfer.py @@ -81,6 +81,7 @@ def dummy_adapted(inputs, def test_logitstrainer(): # Trainer logits try: + _, run_config,_,_,_ = reproducible.set_seed() embedding_validation = False trainer = Logits( model_dir=model_dir, @@ -89,7 +90,9 @@ def test_logitstrainer(): n_classes=10, loss_op=mean_cross_entropy_loss, embedding_validation=embedding_validation, - validation_batch_size=validation_batch_size) + validation_batch_size=validation_batch_size, + config=run_config + ) run_logitstrainer_mnist(trainer, augmentation=True) del trainer @@ -110,7 +113,9 @@ def test_logitstrainer(): loss_op=mean_cross_entropy_loss, embedding_validation=embedding_validation, validation_batch_size=validation_batch_size, - extra_checkpoint=extra_checkpoint) + extra_checkpoint=extra_checkpoint, + config=run_config + ) run_logitstrainer_mnist(trainer, augmentation=True) @@ -129,7 +134,7 @@ def test_logitstrainer_center_loss(): # Trainer logits try: embedding_validation = False - + _, run_config,_,_,_ = reproducible.set_seed() trainer = LogitsCenterLoss( model_dir=model_dir, architecture=dummy, @@ -137,7 +142,9 @@ def test_logitstrainer_center_loss(): n_classes=10, embedding_validation=embedding_validation, validation_batch_size=validation_batch_size, - apply_moving_averages=False) + apply_moving_averages=False, + config=run_config + ) run_logitstrainer_mnist(trainer, augmentation=True) del trainer @@ -158,7 +165,9 @@ def test_logitstrainer_center_loss(): embedding_validation=embedding_validation, validation_batch_size=validation_batch_size, extra_checkpoint=extra_checkpoint, - apply_moving_averages=False) + apply_moving_averages=False, + config=run_config + ) run_logitstrainer_mnist(trainer, augmentation=True) diff --git a/bob/learn/tensorflow/utils/util.py b/bob/learn/tensorflow/utils/util.py index 0750d663677f369a92cc3db54ea859c03bdb60aa..ec3cbc9e051947a33feb3410a27281f0e00400e0 100644 --- a/bob/learn/tensorflow/utils/util.py +++ b/bob/learn/tensorflow/utils/util.py @@ -19,7 +19,7 @@ def compute_euclidean_distance(x, y): def 
load_mnist(perc_train=0.9):
-
+    numpy.random.seed(0)
     import bob.db.mnist
     db = bob.db.mnist.Database()
     raw_data = db.data()
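
For reference, the `tf.contrib.data.Dataset` -> `tf.data.Dataset` change and the `num_threads`/`output_buffer_size` -> `num_parallel_calls`/`prefetch` change above follow the standard TensorFlow 1.4+ input-pipeline idiom. A minimal sketch of that idiom; the parser, file names, shapes and batch settings below are illustrative placeholders, not code from this package:

import tensorflow as tf

# Illustrative file list and labels, in the same order (see the docstrings above).
filenames = ['class_1_img1.png', 'class_1_img2.png', 'class_2_img1.png']
labels = [0, 0, 1]

def parse_image(filename, label):
    # Decode and resize a single image; a stand-in for the package's real parser.
    image = tf.image.decode_png(tf.read_file(filename), channels=3)
    image = tf.image.resize_images(image, [28, 28])
    return tf.cast(image, tf.float32), label

dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
dataset = dataset.map(parse_image, num_parallel_calls=1).prefetch(32)
dataset = dataset.shuffle(buffer_size=100).batch(32).repeat(10)
images, labels_batch = dataset.make_one_shot_iterator().get_next()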
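
The seeding changes (the forced `tf.set_random_seed(0)` in `append_image_augmentation` and the `reproducible.set_seed()` calls in the tests) all serve deterministic runs. A generic sketch of the same idea using only the public TensorFlow APIs referenced above; the `reproducible.set_seed` helper itself is not reproduced here, and the seed and checkpoint values are arbitrary examples:

import random
import numpy
import tensorflow as tf

# Fix Python, NumPy and graph-level TensorFlow randomness.
random.seed(0)
numpy.random.seed(0)
tf.set_random_seed(0)

# A RunConfig with a fixed operation-level seed, passed to the estimators
# via config=run_config as in the tests above.
run_config = tf.estimator.RunConfig(tf_random_seed=0)
run_config = run_config.replace(save_checkpoints_steps=1000)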
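
The transfer-learning path in `Logits._model_fn` reads three keys from `extra_checkpoint`: `checkpoint_path`, `scopes` (the `assignment_map` handed to `tf.train.init_from_checkpoint`) and, through `get_trainable_variables`, `trainable_variables`. A hypothetical example of such a dictionary; only the keys come from the code above, the path and scope names are placeholders:

# Hypothetical extra_checkpoint dictionary for restoring a previously trained
# scope and keeping it frozen during fine-tuning.
extra_checkpoint = {
    "checkpoint_path": "/path/to/previous/model_dir",
    "scopes": {"Dummy/": "Dummy/"},  # assignment_map for init_from_checkpoint
    "trainable_variables": [],       # empty list: no restored variable is declared trainable
}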