From 33d4e3b6e07e9bc4755ff09f5ea7ff12b1d28fc8 Mon Sep 17 00:00:00 2001 From: Tiago Freitas Pereira Date: Fri, 24 Nov 2017 15:47:11 +0100 Subject: [PATCH] Created mechanism that allows as to train only parts of the graph Fixed the contrastive loss equation in the comments Created mechanism that shuts down parts of the graph with respect to backpropagation Fixed small bug Reverted from non_trainable_variables to trainable_variables Pylint the code pylinting Reorganized transfer learning mechanism --- bob/learn/tensorflow/estimators/Logits.py | 51 ++--- bob/learn/tensorflow/estimators/Siamese.py | 76 +++---- bob/learn/tensorflow/estimators/Triplet.py | 53 +++-- bob/learn/tensorflow/estimators/__init__.py | 29 ++- bob/learn/tensorflow/loss/ContrastiveLoss.py | 9 +- bob/learn/tensorflow/network/Dummy.py | 29 ++- .../tensorflow/network/InceptionResnetV1.py | 24 ++- .../tensorflow/network/InceptionResnetV2.py | 193 +++++++++++------- bob/learn/tensorflow/network/LightCNN9.py | 86 +++++--- bob/learn/tensorflow/network/MLP.py | 2 +- bob/learn/tensorflow/network/utils.py | 23 +++ bob/learn/tensorflow/script/train_generic.py | 1 + bob/learn/tensorflow/test/test_dataset.py | 9 +- .../test/test_estimator_onegraph.py | 91 ++++----- .../tensorflow/test/test_estimator_siamese.py | 93 ++++++--- .../test/test_estimator_transfer.py | 112 +++++----- .../tensorflow/test/test_estimator_triplet.py | 42 ++-- .../tensorflow/test/test_image_dataset.py | 41 ++-- 18 files changed, 546 insertions(+), 418 deletions(-) diff --git a/bob/learn/tensorflow/estimators/Logits.py b/bob/learn/tensorflow/estimators/Logits.py index b9c19f8..6a7b714 100755 --- a/bob/learn/tensorflow/estimators/Logits.py +++ b/bob/learn/tensorflow/estimators/Logits.py @@ -8,9 +8,10 @@ from tensorflow.python.estimator import estimator from bob.learn.tensorflow.utils import predict_using_tensors from bob.learn.tensorflow.loss import mean_cross_entropy_center_loss -from . import check_features, is_trainable_checkpoint +from . 
import check_features, get_trainable_variables import logging + logger = logging.getLogger("bob.learn") @@ -77,7 +78,7 @@ class Logits(estimator.Estimator): extra_checkpoint = { "checkpoint_path": , "scopes": dict({"/": "/"}), - "is_trainable": + "trainable_variables": [] } """ @@ -111,14 +112,11 @@ class Logits(estimator.Estimator): # Configure the Training Op (for TRAIN mode) if mode == tf.estimator.ModeKeys.TRAIN: - # Building one graph, by default everything is trainable - if self.extra_checkpoint is None: - is_trainable = True - else: - is_trainable = is_trainable_checkpoint(self.extra_checkpoint) + # Building the training graph - # Building the training graph - prelogits = self.architecture(data, mode=mode, trainable_variables=is_trainable)[0] + # Checking if we have some variables/scope that we may want to shut down + trainable_variables = get_trainable_variables(self.extra_checkpoint) + prelogits = self.architecture(data, mode=mode, trainable_variables=trainable_variables)[0] logits = append_logits(prelogits, n_classes) # Compute Loss (for both TRAIN and EVAL modes) @@ -132,10 +130,9 @@ class Logits(estimator.Estimator): train_op = self.optimizer.minimize(self.loss, global_step=global_step) return tf.estimator.EstimatorSpec(mode=mode, loss=self.loss, train_op=train_op) - # Building the training graph for PREDICTION OR VALIDATION - prelogits = self.architecture(data, mode=mode, trainable_variables=False)[0] + prelogits = self.architecture(data, mode=mode)[0] logits = append_logits(prelogits, n_classes) if self.embedding_validation: @@ -162,7 +159,7 @@ class Logits(estimator.Estimator): # IF Validation self.loss = self.loss_op(logits, labels) - + if self.embedding_validation: predictions_op = predict_using_tensors( predictions["embeddings"], labels, @@ -233,6 +230,15 @@ class LogitsCenterLoss(estimator.Estimator): params: Extra params for the model function (please see https://www.tensorflow.org/extend/estimators for more info) + + extra_checkpoint: dict + In case you want to use other model to initialize some variables. 
+ This argument should be in the following format + extra_checkpoint = { + "checkpoint_path": , + "scopes": dict({"/": "/"}), + "trainable_variables": [] + } """ @@ -279,15 +285,11 @@ class LogitsCenterLoss(estimator.Estimator): # Configure the Training Op (for TRAIN mode) if mode == tf.estimator.ModeKeys.TRAIN: + # Building the training graph - # Building one graph, by default everything is trainable - if self.extra_checkpoint is None: - is_trainable = True - else: - is_trainable = is_trainable_checkpoint(self.extra_checkpoint) - - # Building the training graph - prelogits = self.architecture(data, mode=mode, trainable_variables=is_trainable)[0] + # Checking if we have some variables/scope that we may want to shut down + trainable_variables = get_trainable_variables(self.extra_checkpoint) + prelogits = self.architecture(data, mode=mode, trainable_variables=trainable_variables)[0] logits = append_logits(prelogits, n_classes) # Compute Loss (for TRAIN mode) @@ -308,7 +310,7 @@ class LogitsCenterLoss(estimator.Estimator): train_op=train_op) # Building the training graph for PREDICTION OR VALIDATION - prelogits = self.architecture(data, mode=mode, trainable_variables=False)[0] + prelogits = self.architecture(data, mode=mode)[0] logits = append_logits(prelogits, n_classes) if self.embedding_validation: @@ -334,15 +336,15 @@ class LogitsCenterLoss(estimator.Estimator): # IF Validation loss_dict = mean_cross_entropy_center_loss(logits, prelogits, labels, self.n_classes, alpha=self.alpha, factor=self.factor) - self.loss = loss_dict['loss'] - + self.loss = loss_dict['loss'] + if self.embedding_validation: predictions_op = predict_using_tensors( predictions["embeddings"], labels, num=validation_batch_size) eval_metric_ops = {"accuracy": tf.metrics.accuracy( labels=labels, predictions=predictions_op)} return tf.estimator.EstimatorSpec(mode=mode, loss=self.loss, eval_metric_ops=eval_metric_ops) - + else: # Add evaluation metrics (for EVAL mode) eval_metric_ops = { @@ -354,4 +356,3 @@ class LogitsCenterLoss(estimator.Estimator): super(LogitsCenterLoss, self).__init__(model_fn=_model_fn, model_dir=model_dir, config=config) - diff --git a/bob/learn/tensorflow/estimators/Siamese.py b/bob/learn/tensorflow/estimators/Siamese.py index 5259b3c..5b1b6be 100755 --- a/bob/learn/tensorflow/estimators/Siamese.py +++ b/bob/learn/tensorflow/estimators/Siamese.py @@ -3,24 +3,13 @@ # @author: Tiago de Freitas Pereira import tensorflow as tf -import os -import bob.io.base -import bob.core -from tensorflow.core.framework import summary_pb2 -import time - -#logger = bob.core.log.setup("bob.learn.tensorflow") from tensorflow.python.estimator import estimator from bob.learn.tensorflow.utils import predict_using_tensors -#from bob.learn.tensorflow.loss import mean_cross_entropy_center_loss -from . import check_features, is_trainable_checkpoint +from . 
import check_features, get_trainable_variables import logging -logger = logging.getLogger("bob.learn") - -from bob.learn.tensorflow.network.utils import append_logits -from bob.learn.tensorflow.loss import mean_cross_entropy_loss +logger = logging.getLogger("bob.learn") class Siamese(estimator.Estimator): @@ -45,7 +34,7 @@ class Siamese(estimator.Estimator): extra_checkpoint = {"checkpoint_path":model_dir, "scopes": dict({"Dummy/": "Dummy/"}), - "is_trainable": False + "trainable_variables": [] } @@ -81,13 +70,14 @@ class Siamese(estimator.Estimator): Extra params for the model function (please see https://www.tensorflow.org/extend/estimators for more info) - extra_checkpoint: dict() + extra_checkpoint: dict In case you want to use other model to initialize some variables. This argument should be in the following format - extra_checkpoint = {"checkpoint_path": , - "scopes": dict({"/": "/"}), - "is_trainable": - } + extra_checkpoint = { + "checkpoint_path": , + "scopes": dict({"/": "/"}), + "trainable_variables": [] + } """ @@ -99,18 +89,18 @@ class Siamese(estimator.Estimator): model_dir="", validation_batch_size=None, params=None, - extra_checkpoint=None - ): + extra_checkpoint=None + ): self.architecture = architecture - self.optimizer=optimizer - self.loss_op=loss_op + self.optimizer = optimizer + self.loss_op = loss_op self.loss = None - self.extra_checkpoint = extra_checkpoint + self.extra_checkpoint = extra_checkpoint if self.architecture is None: raise ValueError("Please specify a function to build the architecture !!") - + if self.optimizer is None: raise ValueError("Please specify a optimizer (https://www.tensorflow.org/api_guides/python/train) !!") @@ -118,30 +108,28 @@ class Siamese(estimator.Estimator): raise ValueError("Please specify a function to build the loss !!") def _model_fn(features, labels, mode, params, config): + if mode == tf.estimator.ModeKeys.TRAIN: - if mode == tf.estimator.ModeKeys.TRAIN: - # Building one graph, by default everything is trainable - if self.extra_checkpoint is None: - is_trainable = True - else: - is_trainable = is_trainable_checkpoint(self.extra_checkpoint) - # The input function needs to have dictionary pair with the `left` and `right` keys - if not 'left' in features.keys() or not 'right' in features.keys(): - raise ValueError("The input function needs to contain a dictionary with the keys `left` and `right` ") + if 'left' not in features.keys() or 'right' not in features.keys(): + raise ValueError( + "The input function needs to contain a dictionary with the keys `left` and `right` ") # Building one graph - prelogits_left, end_points_left = self.architecture(features['left'], mode=mode, trainable_variables=is_trainable) - prelogits_right, end_points_right = self.architecture(features['right'], reuse=True, mode=mode, trainable_variables=is_trainable) + trainable_variables = get_trainable_variables(self.extra_checkpoint) + prelogits_left, end_points_left = self.architecture(features['left'], mode=mode, + trainable_variables=trainable_variables) + prelogits_right, end_points_right = self.architecture(features['right'], reuse=True, mode=mode, + trainable_variables=trainable_variables) if self.extra_checkpoint is not None: tf.contrib.framework.init_from_checkpoint(self.extra_checkpoint["checkpoint_path"], self.extra_checkpoint["scopes"]) - + # Compute Loss (for both TRAIN and EVAL modes) self.loss = self.loss_op(prelogits_left, prelogits_right, labels) - + # Configure the Training Op (for TRAIN mode) global_step = 
tf.contrib.framework.get_or_create_global_step() train_op = self.optimizer.minimize(self.loss, global_step=global_step) @@ -153,7 +141,7 @@ class Siamese(estimator.Estimator): data = features['data'] # Compute the embeddings - prelogits = self.architecture(data, mode=mode, trainable_variables=False)[0] + prelogits = self.architecture(data, mode=mode)[0] embeddings = tf.nn.l2_normalize(prelogits, 1) predictions = {"embeddings": embeddings} @@ -162,12 +150,10 @@ class Siamese(estimator.Estimator): predictions_op = predict_using_tensors(predictions["embeddings"], labels, num=validation_batch_size) eval_metric_ops = {"accuracy": tf.metrics.accuracy(labels=labels, predictions=predictions_op)} - + return tf.estimator.EstimatorSpec(mode=mode, loss=tf.reduce_mean(1), eval_metric_ops=eval_metric_ops) - super(Siamese, self).__init__(model_fn=_model_fn, - model_dir=model_dir, - params=params, - config=config) - + model_dir=model_dir, + params=params, + config=config) diff --git a/bob/learn/tensorflow/estimators/Triplet.py b/bob/learn/tensorflow/estimators/Triplet.py index 1981e3f..3b369e3 100755 --- a/bob/learn/tensorflow/estimators/Triplet.py +++ b/bob/learn/tensorflow/estimators/Triplet.py @@ -3,19 +3,13 @@ # @author: Tiago de Freitas Pereira import tensorflow as tf -import os -import bob.io.base -import bob.core -from tensorflow.core.framework import summary_pb2 -import time - -#logger = bob.core.log.setup("bob.learn.tensorflow") from tensorflow.python.estimator import estimator from bob.learn.tensorflow.utils import predict_using_tensors from bob.learn.tensorflow.loss import triplet_loss -from . import check_features, is_trainable_checkpoint +from . import check_features, get_trainable_variables import logging + logger = logging.getLogger("bob.learn") @@ -70,13 +64,14 @@ class Triplet(estimator.Estimator): Size of the batch for validation. This value is used when the validation with embeddings is used. This is a hack. - extra_checkpoint: dict() + extra_checkpoint: dict In case you want to use other model to initialize some variables. 
This argument should be in the following format - extra_checkpoint = {"checkpoint_path": , - "scopes": dict({"/": "/"}), - "is_trainable": - } + extra_checkpoint = { + "checkpoint_path": , + "scopes": dict({"/": "/"}), + "trainable_variables": [] + } """ def __init__(self, @@ -87,17 +82,17 @@ class Triplet(estimator.Estimator): model_dir="", validation_batch_size=None, extra_checkpoint=None - ): + ): self.architecture = architecture - self.optimizer=optimizer - self.loss_op=loss_op + self.optimizer = optimizer + self.loss_op = loss_op self.loss = None self.extra_checkpoint = extra_checkpoint if self.architecture is None: raise ValueError("Please specify a function to build the architecture !!") - + if self.optimizer is None: raise ValueError("Please specify a optimizer (https://www.tensorflow.org/api_guides/python/train) !!") @@ -109,21 +104,20 @@ class Triplet(estimator.Estimator): if mode == tf.estimator.ModeKeys.TRAIN: # The input function needs to have dictionary pair with the `left` and `right` keys - if not 'anchor' in features.keys() or not \ - 'positive' in features.keys() or not \ - 'negative' in features.keys(): + if 'anchor' not in features.keys() or \ + 'positive' not in features.keys() or \ + 'negative' not in features.keys(): raise ValueError("The input function needs to contain a dictionary with the " "keys `anchor`, `positive` and `negative` ") - if self.extra_checkpoint is None: - is_trainable = True - else: - is_trainable = is_trainable_checkpoint(self.extra_checkpoint) - # Building one graph - prelogits_anchor = self.architecture(features['anchor'], mode=mode)[0] - prelogits_positive = self.architecture(features['positive'], reuse=True, mode=mode)[0] - prelogits_negative = self.architecture(features['negative'], reuse=True, mode=mode)[0] + trainable_variables = get_trainable_variables(self.extra_checkpoint) + prelogits_anchor = self.architecture(features['anchor'], mode=mode, + trainable_variables=trainable_variables)[0] + prelogits_positive = self.architecture(features['positive'], reuse=True, mode=mode, + trainable_variables=trainable_variables)[0] + prelogits_negative = self.architecture(features['negative'], reuse=True, mode=mode, + trainable_variables=trainable_variables)[0] if self.extra_checkpoint is not None: tf.contrib.framework.init_from_checkpoint(self.extra_checkpoint["checkpoint_path"], @@ -150,10 +144,9 @@ class Triplet(estimator.Estimator): predictions_op = predict_using_tensors(predictions["embeddings"], labels, num=validation_batch_size) eval_metric_ops = {"accuracy": tf.metrics.accuracy(labels=labels, predictions=predictions_op)} - + return tf.estimator.EstimatorSpec(mode=mode, loss=tf.reduce_mean(1), eval_metric_ops=eval_metric_ops) super(Triplet, self).__init__(model_fn=_model_fn, model_dir=model_dir, config=config) - diff --git a/bob/learn/tensorflow/estimators/__init__.py b/bob/learn/tensorflow/estimators/__init__.py index 35dd9ca..d70ad87 100755 --- a/bob/learn/tensorflow/estimators/__init__.py +++ b/bob/learn/tensorflow/estimators/__init__.py @@ -4,18 +4,39 @@ import tensorflow as tf + def check_features(features): if not 'data' in features.keys() or not 'key' in features.keys(): raise ValueError("The input function needs to contain a dictionary with the keys `data` and `key` ") return True -def is_trainable_checkpoint(params): +def get_trainable_variables(extra_checkpoint): + """ + Given the extra_checkpoint dictionary provided to the estimator, + extract the content of "trainable_variables" e. 
+ + If trainable_variables is not provided, all end points are trainable by default. + If trainable_variables==[], all end points are NOT trainable. + If trainable_variables contains some end_points, ONLY these endpoints will be trainable. + + Parameters + ---------- + extra_checkpoint: dict + The `extra_checkpoint dictionary provided to the estimator + + Returns + ------- + Returns `None` if `trainable_variables` is not in extra_checkpoint; + otherwise returns the content of `extra_checkpoint + + """ - if not "is_trainable" in params: - raise ValueError("Param `is_trainable` is missing in `load_variable_from_checkpoint` dictionary") + # If you don't set anything, everything is trainable + if extra_checkpoint is None or "trainable_variables" not in extra_checkpoint: + return None - return params["is_trainable"] + return extra_checkpoint["trainable_variables"] from .Logits import Logits, LogitsCenterLoss diff --git a/bob/learn/tensorflow/loss/ContrastiveLoss.py b/bob/learn/tensorflow/loss/ContrastiveLoss.py index 62f93da..4dbc250 100755 --- a/bob/learn/tensorflow/loss/ContrastiveLoss.py +++ b/bob/learn/tensorflow/loss/ContrastiveLoss.py @@ -9,13 +9,16 @@ import tensorflow as tf from bob.learn.tensorflow.utils import compute_euclidean_distance -def contrastive_loss(left_embedding, right_embedding, labels, contrastive_margin=1.0): +def contrastive_loss(left_embedding, right_embedding, labels, contrastive_margin=2.0): """ Compute the contrastive loss as in http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf - :math:`L = 0.5 * (Y) * D^2 + 0.5 * (1-Y) * {max(0, margin - D)}^2` + :math:`L = 0.5 * (1-Y) * D^2 + 0.5 * (Y) * {max(0, margin - D)}^2` + + where, `0` are assign for pairs from the same class and `1` from pairs from different classes. 
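Review note: a minimal NumPy sketch of the corrected contrastive-loss formula above (the function name and the use of NumPy instead of the TensorFlow ops of the real implementation are only for illustration):

    import numpy as np

    def contrastive_loss_sketch(left, right, labels, margin=2.0):
        # L = 0.5 * (1-Y) * D^2 + 0.5 * Y * max(0, margin - D)^2
        labels = np.asarray(labels, dtype=float)   # Y: 0 = same-class pair, 1 = different-class pair
        d = np.linalg.norm(left - right, axis=1)   # D: Euclidean distance between the pair embeddings
        genuine_term = 0.5 * (1.0 - labels) * d ** 2                     # pulls genuine pairs together
        impostor_term = 0.5 * labels * np.maximum(0.0, margin - d) ** 2  # pushes impostors beyond the margin
        return np.mean(genuine_term + impostor_term)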
+ **Parameters** @@ -68,7 +71,7 @@ def contrastive_loss_deprecated(left_embedding, right_embedding, labels, contras http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf - :math:`L = 0.5 * (Y) * D^2 + 0.5 * (1-Y) * {max(0, margin - D)}^2` + :math:`L = 0.5 * (1-Y) * D^2 + 0.5 * (Y) * {max(0, margin - D)}^2` **Parameters** diff --git a/bob/learn/tensorflow/network/Dummy.py b/bob/learn/tensorflow/network/Dummy.py index 9ba72e8..374f2ae 100755 --- a/bob/learn/tensorflow/network/Dummy.py +++ b/bob/learn/tensorflow/network/Dummy.py @@ -3,29 +3,38 @@ # @author: Tiago de Freitas Pereira import tensorflow as tf +from .utils import is_trainable -def dummy(inputs, reuse=False, mode = tf.estimator.ModeKeys.TRAIN, trainable_variables=True): +def dummy(inputs, reuse=False, mode=tf.estimator.ModeKeys.TRAIN, trainable_variables=None, **kwargs): """ Create all the necessary variables for this CNN - **Parameters** + Parameters + ---------- inputs: reuse: + + mode: + + trainable_variables: + """ slim = tf.contrib.slim end_points = dict() + # Here is my choice to shutdown the whole scope + trainable = is_trainable("Dummy", trainable_variables) with tf.variable_scope('Dummy', reuse=reuse): initializer = tf.contrib.layers.xavier_initializer() - - graph = slim.conv2d(inputs, 10, [3, 3], activation_fn=tf.nn.relu, stride=1, scope='conv1', + name = 'conv1' + graph = slim.conv2d(inputs, 10, [3, 3], activation_fn=tf.nn.relu, stride=1, scope=name, weights_initializer=initializer, - trainable=trainable_variables) - end_points['conv1'] = graph + trainable=trainable) + end_points[name] = graph graph = slim.max_pool2d(graph, [4, 4], scope='pool1') end_points['pool1'] = graph @@ -33,13 +42,13 @@ def dummy(inputs, reuse=False, mode = tf.estimator.ModeKeys.TRAIN, trainable_var graph = slim.flatten(graph, scope='flatten1') end_points['flatten1'] = graph + name = 'fc1' graph = slim.fully_connected(graph, 50, weights_initializer=initializer, activation_fn=None, - scope='fc1', - trainable=trainable_variables) - end_points['fc1'] = graph - + scope=name, + trainable=trainable) + end_points[name] = graph return graph, end_points diff --git a/bob/learn/tensorflow/network/InceptionResnetV1.py b/bob/learn/tensorflow/network/InceptionResnetV1.py index 9c0133b..f5b2382 100755 --- a/bob/learn/tensorflow/network/InceptionResnetV1.py +++ b/bob/learn/tensorflow/network/InceptionResnetV1.py @@ -26,6 +26,7 @@ from __future__ import print_function import tensorflow as tf import tensorflow.contrib.slim as slim + # Inception-Renset-A def block35(net, scale=1.0, activation_fn=tf.nn.relu, scope=None, reuse=None, trainable_variables=True): """Builds the 35x35 resnet block.""" @@ -47,6 +48,7 @@ def block35(net, scale=1.0, activation_fn=tf.nn.relu, scope=None, reuse=None, tr net = activation_fn(net) return net + # Inception-Renset-B def block17(net, scale=1.0, activation_fn=tf.nn.relu, scope=None, reuse=None, trainable_variables=True): """Builds the 17x17 resnet block.""" @@ -87,7 +89,8 @@ def block8(net, scale=1.0, activation_fn=tf.nn.relu, scope=None, reuse=None, tra if activation_fn: net = activation_fn(net) return net - + + def reduction_a(net, k, l, m, n, trainable_variables=True, reuse=None): with tf.variable_scope('Branch_0', reuse=reuse): tower_conv = slim.conv2d(net, n, 3, stride=2, padding='VALID', @@ -105,6 +108,7 @@ def reduction_a(net, k, l, m, n, trainable_variables=True, reuse=None): net = tf.concat([tower_conv, tower_conv1_2, tower_pool], 3) return net + def reduction_b(net, reuse=None, trainable_variables=True): with 
tf.variable_scope('Branch_0', reuse=reuse): tower_conv = slim.conv2d(net, 256, 1, scope='Conv2d_0a_1x1', trainable=trainable_variables) @@ -124,9 +128,10 @@ def reduction_b(net, reuse=None, trainable_variables=True): tower_pool = slim.max_pool2d(net, 3, stride=2, padding='VALID', scope='MaxPool_1a_3x3') net = tf.concat([tower_conv_1, tower_conv1_1, - tower_conv2_2, tower_pool], 3) + tower_conv2_2, tower_pool], 3) return net - + + def inference(images, keep_probability, phase_train=True, bottleneck_layer_size=128, weight_decay=0.0, reuse=None): batch_norm_params = { @@ -146,16 +151,18 @@ def inference(images, keep_probability, phase_train=True, normalizer_fn=slim.batch_norm, normalizer_params=batch_norm_params): return inception_resnet_v1(images, is_training=phase_train, - dropout_keep_prob=keep_probability, bottleneck_layer_size=bottleneck_layer_size, reuse=reuse) + dropout_keep_prob=keep_probability, + bottleneck_layer_size=bottleneck_layer_size, + reuse=reuse) -def inception_resnet_v1(inputs, is_training=True, +def inception_resnet_v1(inputs, dropout_keep_prob=0.8, bottleneck_layer_size=128, reuse=None, scope='InceptionResnetV1', - mode = tf.estimator.ModeKeys.TRAIN, - trainable_variables=True): + mode=tf.estimator.ModeKeys.TRAIN, + trainable_variables=True, **kwargs): """ Creates the Inception Resnet V1 model. @@ -248,6 +255,7 @@ def inception_resnet_v1(inputs, is_training=True, end_points['PreLogitsFlatten'] = net net = slim.fully_connected(net, bottleneck_layer_size, activation_fn=None, - scope='Bottleneck', reuse=reuse, trainable=trainable_variables) + scope='Bottleneck', + reuse=reuse, trainable=trainable_variables) return net, end_points diff --git a/bob/learn/tensorflow/network/InceptionResnetV2.py b/bob/learn/tensorflow/network/InceptionResnetV2.py index 5ff38f9..b2427a0 100755 --- a/bob/learn/tensorflow/network/InceptionResnetV2.py +++ b/bob/learn/tensorflow/network/InceptionResnetV2.py @@ -25,6 +25,8 @@ from __future__ import print_function import tensorflow as tf import tensorflow.contrib.slim as slim +from .utils import is_trainable + # Inception-Renset-A def block35(net, scale=1.0, activation_fn=tf.nn.relu, scope=None, reuse=None, trainable_variables=True): @@ -34,11 +36,14 @@ def block35(net, scale=1.0, activation_fn=tf.nn.relu, scope=None, reuse=None, tr tower_conv = slim.conv2d(net, 32, 1, scope='Conv2d_1x1', trainable=trainable_variables, reuse=reuse) with tf.variable_scope('Branch_1'): tower_conv1_0 = slim.conv2d(net, 32, 1, scope='Conv2d_0a_1x1', trainable=trainable_variables, reuse=reuse) - tower_conv1_1 = slim.conv2d(tower_conv1_0, 32, 3, scope='Conv2d_0b_3x3', trainable=trainable_variables, reuse=reuse) + tower_conv1_1 = slim.conv2d(tower_conv1_0, 32, 3, scope='Conv2d_0b_3x3', trainable=trainable_variables, + reuse=reuse) with tf.variable_scope('Branch_2'): tower_conv2_0 = slim.conv2d(net, 32, 1, scope='Conv2d_0a_1x1', trainable=trainable_variables, reuse=reuse) - tower_conv2_1 = slim.conv2d(tower_conv2_0, 48, 3, scope='Conv2d_0b_3x3', trainable=trainable_variables, reuse=reuse) - tower_conv2_2 = slim.conv2d(tower_conv2_1, 64, 3, scope='Conv2d_0c_3x3', trainable=trainable_variables, reuse=reuse) + tower_conv2_1 = slim.conv2d(tower_conv2_0, 48, 3, scope='Conv2d_0b_3x3', trainable=trainable_variables, + reuse=reuse) + tower_conv2_2 = slim.conv2d(tower_conv2_1, 64, 3, scope='Conv2d_0c_3x3', trainable=trainable_variables, + reuse=reuse) mixed = tf.concat([tower_conv, tower_conv1_1, tower_conv2_2], 3) up = slim.conv2d(mixed, net.get_shape()[3], 1, normalizer_fn=None, 
activation_fn=None, scope='Conv2d_1x1', trainable=trainable_variables, reuse=reuse) @@ -47,6 +52,7 @@ def block35(net, scale=1.0, activation_fn=tf.nn.relu, scope=None, reuse=None, tr net = activation_fn(net) return net + # Inception-Renset-B def block17(net, scale=1.0, activation_fn=tf.nn.relu, scope=None, reuse=None, trainable_variables=True): """Builds the 17x17 resnet block.""" @@ -87,8 +93,9 @@ def block8(net, scale=1.0, activation_fn=tf.nn.relu, scope=None, reuse=None, tra if activation_fn: net = activation_fn(net) return net - -def inference(images, keep_probability, phase_train=True, + + +def inference(images, keep_probability, bottleneck_layer_size=128, weight_decay=0.0, reuse=None): batch_norm_params = { # Decay for the moving averages. @@ -98,24 +105,26 @@ def inference(images, keep_probability, phase_train=True, # force in-place updates of mean and variance estimates 'updates_collections': None, # Moving averages ends up in the trainable variables collection - 'variables_collections': [ tf.GraphKeys.TRAINABLE_VARIABLES ], -} + 'variables_collections': [tf.GraphKeys.TRAINABLE_VARIABLES], + } with slim.arg_scope([slim.conv2d, slim.fully_connected], weights_initializer=tf.truncated_normal_initializer(stddev=0.1), weights_regularizer=slim.l2_regularizer(weight_decay), normalizer_fn=slim.batch_norm, normalizer_params=batch_norm_params): - return inception_resnet_v2(images, mode = tf.estimator.ModeKeys.PREDICT, - dropout_keep_prob=keep_probability, bottleneck_layer_size=bottleneck_layer_size, reuse=reuse) + return inception_resnet_v2(images, mode=tf.estimator.ModeKeys.PREDICT, + dropout_keep_prob=keep_probability, + bottleneck_layer_size=bottleneck_layer_size, reuse=reuse) -def inception_resnet_v2(inputs, +def inception_resnet_v2(inputs, dropout_keep_prob=0.8, bottleneck_layer_size=128, reuse=None, scope='InceptionResnetV2', - mode = tf.estimator.ModeKeys.TRAIN, - trainable_variables=True): + mode=tf.estimator.ModeKeys.TRAIN, + trainable_variables=None, + **kwargs): """Creates the Inception Resnet V2 model. **Parameters**: @@ -132,6 +141,9 @@ def inception_resnet_v2(inputs, able to reuse 'scope' must be given. scope: Optional variable_scope. 
+ + trainable_variables: list + List of variables to be trainable=True **Returns**: @@ -144,124 +156,161 @@ def inception_resnet_v2(inputs, is_training=(mode == tf.estimator.ModeKeys.TRAIN)): with slim.arg_scope([slim.conv2d, slim.max_pool2d, slim.avg_pool2d], stride=1, padding='SAME'): - # 149 x 149 x 32 + name = "Conv2d_1a_3x3" + trainable = is_trainable(name, trainable_variables) net = slim.conv2d(inputs, 32, 3, stride=2, padding='VALID', - scope='Conv2d_1a_3x3', trainable=trainable_variables, reuse=reuse) - end_points['Conv2d_1a_3x3'] = net + scope=name, trainable=trainable, reuse=reuse) + end_points[name] = net + # 147 x 147 x 32 + name = "Conv2d_2a_3x3" + trainable = is_trainable(name, trainable_variables) net = slim.conv2d(net, 32, 3, padding='VALID', - scope='Conv2d_2a_3x3', trainable=trainable_variables, reuse=reuse) - end_points['Conv2d_2a_3x3'] = net + scope=name, trainable=trainable, reuse=reuse) + end_points[name] = net + # 147 x 147 x 64 - net = slim.conv2d(net, 64, 3, scope='Conv2d_2b_3x3', trainable=trainable_variables, reuse=reuse) - end_points['Conv2d_2b_3x3'] = net + name = "Conv2d_2b_3x3" + trainable = is_trainable(name, trainable_variables) + net = slim.conv2d(net, 64, 3, scope=name, trainable=trainable, reuse=reuse) + end_points[name] = net + # 73 x 73 x 64 net = slim.max_pool2d(net, 3, stride=2, padding='VALID', scope='MaxPool_3a_3x3') end_points['MaxPool_3a_3x3'] = net + # 73 x 73 x 80 + name = "Conv2d_3b_1x1" + trainable = is_trainable(name, trainable_variables) net = slim.conv2d(net, 80, 1, padding='VALID', - scope='Conv2d_3b_1x1', trainable=trainable_variables, reuse=reuse) - end_points['Conv2d_3b_1x1'] = net + scope=name, trainable=trainable, reuse=reuse) + end_points[name] = net + # 71 x 71 x 192 + name = "Conv2d_4a_3x3" + trainable = is_trainable(name, trainable_variables) net = slim.conv2d(net, 192, 3, padding='VALID', - scope='Conv2d_4a_3x3', trainable=trainable_variables, reuse=reuse) - end_points['Conv2d_4a_3x3'] = net + scope=name, trainable=trainable, reuse=reuse) + end_points[name] = net + # 35 x 35 x 192 net = slim.max_pool2d(net, 3, stride=2, padding='VALID', scope='MaxPool_5a_3x3') end_points['MaxPool_5a_3x3'] = net - + # 35 x 35 x 320 - with tf.variable_scope('Mixed_5b'): + name = "Mixed_5b" + trainable = is_trainable(name, trainable_variables) + with tf.variable_scope(name): with tf.variable_scope('Branch_0'): - tower_conv = slim.conv2d(net, 96, 1, scope='Conv2d_1x1', trainable=trainable_variables, reuse=reuse) + tower_conv = slim.conv2d(net, 96, 1, scope='Conv2d_1x1', trainable=trainable, reuse=reuse) with tf.variable_scope('Branch_1'): - tower_conv1_0 = slim.conv2d(net, 48, 1, scope='Conv2d_0a_1x1', trainable=trainable_variables, reuse=reuse) + tower_conv1_0 = slim.conv2d(net, 48, 1, scope='Conv2d_0a_1x1', trainable=trainable, reuse=reuse) tower_conv1_1 = slim.conv2d(tower_conv1_0, 64, 5, - scope='Conv2d_0b_5x5', trainable=trainable_variables, reuse=reuse) + scope='Conv2d_0b_5x5', trainable=trainable, reuse=reuse) with tf.variable_scope('Branch_2'): - tower_conv2_0 = slim.conv2d(net, 64, 1, scope='Conv2d_0a_1x1', trainable=trainable_variables, reuse=reuse) + tower_conv2_0 = slim.conv2d(net, 64, 1, scope='Conv2d_0a_1x1', trainable=trainable, reuse=reuse) tower_conv2_1 = slim.conv2d(tower_conv2_0, 96, 3, - scope='Conv2d_0b_3x3', trainable=trainable_variables, reuse=reuse) + scope='Conv2d_0b_3x3', trainable=trainable, reuse=reuse) tower_conv2_2 = slim.conv2d(tower_conv2_1, 96, 3, - scope='Conv2d_0c_3x3', trainable=trainable_variables, reuse=reuse) + 
scope='Conv2d_0c_3x3', trainable=trainable, reuse=reuse) with tf.variable_scope('Branch_3'): tower_pool = slim.avg_pool2d(net, 3, stride=1, padding='SAME', scope='AvgPool_0a_3x3') tower_pool_1 = slim.conv2d(tower_pool, 64, 1, - scope='Conv2d_0b_1x1', trainable=trainable_variables, reuse=reuse) + scope='Conv2d_0b_1x1', trainable=trainable, reuse=reuse) net = tf.concat([tower_conv, tower_conv1_1, - tower_conv2_2, tower_pool_1], 3) - - end_points['Mixed_5b'] = net - net = slim.repeat(net, 10, block35, scale=0.17, trainable_variables=trainable_variables, reuse=reuse) - + tower_conv2_2, tower_pool_1], 3) + end_points[name] = net + + # BLOCK 35 + name = "Block35" + trainable = is_trainable(name, trainable_variables) + net = slim.repeat(net, 10, block35, scale=0.17, trainable_variables=trainable, reuse=reuse) + # 17 x 17 x 1024 - with tf.variable_scope('Mixed_6a'): + name = "Mixed_6a" + trainable = is_trainable(name, trainable_variables) + with tf.variable_scope(name): with tf.variable_scope('Branch_0'): tower_conv = slim.conv2d(net, 384, 3, stride=2, padding='VALID', - scope='Conv2d_1a_3x3', trainable=trainable_variables, reuse=reuse) + scope='Conv2d_1a_3x3', trainable=trainable, reuse=reuse) with tf.variable_scope('Branch_1'): - tower_conv1_0 = slim.conv2d(net, 256, 1, scope='Conv2d_0a_1x1', trainable=trainable_variables, reuse=reuse) + tower_conv1_0 = slim.conv2d(net, 256, 1, scope='Conv2d_0a_1x1', trainable=trainable, + reuse=reuse) tower_conv1_1 = slim.conv2d(tower_conv1_0, 256, 3, - scope='Conv2d_0b_3x3', trainable=trainable_variables, reuse=reuse) + scope='Conv2d_0b_3x3', trainable=trainable, reuse=reuse) tower_conv1_2 = slim.conv2d(tower_conv1_1, 384, 3, stride=2, padding='VALID', - scope='Conv2d_1a_3x3', trainable=trainable_variables, reuse=reuse) + scope='Conv2d_1a_3x3', trainable=trainable, reuse=reuse) with tf.variable_scope('Branch_2'): tower_pool = slim.max_pool2d(net, 3, stride=2, padding='VALID', scope='MaxPool_1a_3x3') net = tf.concat([tower_conv, tower_conv1_2, tower_pool], 3) - - end_points['Mixed_6a'] = net - net = slim.repeat(net, 20, block17, scale=0.10,trainable_variables=trainable_variables, reuse=reuse) - - with tf.variable_scope('Mixed_7a'): + + end_points[name] = net + + # BLOCK 17 + name = "Block17" + trainable = is_trainable(name, trainable_variables) + net = slim.repeat(net, 20, block17, scale=0.10, trainable_variables=trainable, reuse=reuse) + + name = "Mixed_7a" + trainable = is_trainable(name, trainable_variables) + with tf.variable_scope(name): with tf.variable_scope('Branch_0'): - tower_conv = slim.conv2d(net, 256, 1, scope='Conv2d_0a_1x1', trainable=trainable_variables, reuse=reuse) + tower_conv = slim.conv2d(net, 256, 1, scope='Conv2d_0a_1x1', trainable=trainable, reuse=reuse) tower_conv_1 = slim.conv2d(tower_conv, 384, 3, stride=2, - padding='VALID', scope='Conv2d_1a_3x3', trainable=trainable_variables, reuse=reuse) + padding='VALID', scope='Conv2d_1a_3x3', trainable=trainable, + reuse=reuse) with tf.variable_scope('Branch_1'): - tower_conv1 = slim.conv2d(net, 256, 1, scope='Conv2d_0a_1x1', trainable=trainable_variables, reuse=reuse) + tower_conv1 = slim.conv2d(net, 256, 1, scope='Conv2d_0a_1x1', trainable=trainable, reuse=reuse) tower_conv1_1 = slim.conv2d(tower_conv1, 288, 3, stride=2, - padding='VALID', scope='Conv2d_1a_3x3', trainable=trainable_variables, reuse=reuse) + padding='VALID', scope='Conv2d_1a_3x3', trainable=trainable, + reuse=reuse) with tf.variable_scope('Branch_2'): - tower_conv2 = slim.conv2d(net, 256, 1, scope='Conv2d_0a_1x1', 
trainable=trainable_variables, reuse=reuse) + tower_conv2 = slim.conv2d(net, 256, 1, scope='Conv2d_0a_1x1', trainable=trainable, reuse=reuse) tower_conv2_1 = slim.conv2d(tower_conv2, 288, 3, - scope='Conv2d_0b_3x3', trainable=trainable_variables, reuse=reuse) + scope='Conv2d_0b_3x3', trainable=trainable, reuse=reuse) tower_conv2_2 = slim.conv2d(tower_conv2_1, 320, 3, stride=2, - padding='VALID', scope='Conv2d_1a_3x3', trainable=trainable_variables, reuse=reuse) + padding='VALID', scope='Conv2d_1a_3x3', trainable=trainable, + reuse=reuse) with tf.variable_scope('Branch_3'): tower_pool = slim.max_pool2d(net, 3, stride=2, padding='VALID', scope='MaxPool_1a_3x3') net = tf.concat([tower_conv_1, tower_conv1_1, - tower_conv2_2, tower_pool], 3) - - end_points['Mixed_7a'] = net - - net = slim.repeat(net, 9, block8, scale=0.20,trainable_variables=trainable_variables, reuse=reuse) - net = block8(net, activation_fn=None,trainable_variables=trainable_variables, reuse=reuse) - - net = slim.conv2d(net, 1536, 1, scope='Conv2d_7b_1x1', trainable=trainable_variables, reuse=reuse) - end_points['Conv2d_7b_1x1'] = net - + tower_conv2_2, tower_pool], 3) + end_points[name] = net + + # Block 8 + name = "Block8" + trainable = is_trainable(name, trainable_variables) + net = slim.repeat(net, 9, block8, scale=0.20, trainable_variables=trainable, reuse=reuse) + net = block8(net, activation_fn=None, trainable_variables=trainable, reuse=reuse) + + name = "Conv2d_7b_1x1" + trainable = is_trainable(name, trainable_variables) + net = slim.conv2d(net, 1536, 1, scope=name, trainable=trainable, reuse=reuse) + end_points[name] = net + with tf.variable_scope('Logits'): end_points['PrePool'] = net - #pylint: disable=no-member + # pylint: disable=no-member net = slim.avg_pool2d(net, net.get_shape()[1:3], padding='VALID', scope='AvgPool_1a_8x8') net = slim.flatten(net) - + net = slim.dropout(net, dropout_keep_prob, scope='Dropout') - + end_points['PreLogitsFlatten'] = net - - net = slim.fully_connected(net, bottleneck_layer_size, activation_fn=None, - scope='Bottleneck', reuse=reuse, trainable=trainable_variables) - end_points['Bottleneck'] = net - return net, end_points + name = "Bottleneck" + trainable = is_trainable(name, trainable_variables) + net = slim.fully_connected(net, bottleneck_layer_size, activation_fn=None, + scope=name, reuse=reuse, trainable=trainable) + end_points[name] = net + return net, end_points diff --git a/bob/learn/tensorflow/network/LightCNN9.py b/bob/learn/tensorflow/network/LightCNN9.py index 296e5e6..451a5cf 100755 --- a/bob/learn/tensorflow/network/LightCNN9.py +++ b/bob/learn/tensorflow/network/LightCNN9.py @@ -4,9 +4,10 @@ import tensorflow as tf from bob.learn.tensorflow.layers import maxout -from .utils import append_logits +from .utils import is_trainable -def light_cnn9(inputs, seed=10, reuse=False): + +def light_cnn9(inputs, seed=10, reuse=False, trainable_variables=None, **kwargs): """Creates the graph for the Light CNN-9 in Wu, Xiang, et al. "A light CNN for deep face representation with noisy labels." arXiv preprint arXiv:1511.02683 (2015). 
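Review note: a short sketch of the per-scope gating pattern used in the networks above, based on the is_trainable helper this patch adds to bob/learn/tensorflow/network/utils.py (the scope names are the InceptionResnetV2 ones):

    from bob.learn.tensorflow.network.utils import is_trainable

    # trainable_variables=None -> every scope is trainable (default behaviour),
    # []                       -> every scope is frozen,
    # a list of scope names    -> only the listed scopes are fine-tuned.
    for trainable_variables in (None, [], ["Block8", "Bottleneck"]):
        trained = [name for name in ("Conv2d_1a_3x3", "Block8", "Bottleneck")
                   if is_trainable(name, trainable_variables)]
        print(trainable_variables, "->", trained)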
@@ -14,17 +15,18 @@ def light_cnn9(inputs, seed=10, reuse=False): slim = tf.contrib.slim with tf.variable_scope('LightCNN9', reuse=reuse): - initializer = tf.contrib.layers.xavier_initializer(uniform=False, dtype=tf.float32, seed=seed) end_points = dict() - + name = "Conv1" + trainable = is_trainable(name, trainable_variables) graph = slim.conv2d(inputs, 96, [5, 5], activation_fn=tf.nn.relu, stride=1, weights_initializer=initializer, - scope='Conv1', + scope=name, + trainable=trainable, reuse=reuse) - end_points['conv1'] = graph - + end_points[name] = graph + graph = maxout(graph, num_units=48, name='Maxout1') @@ -32,24 +34,29 @@ def light_cnn9(inputs, seed=10, reuse=False): graph = slim.max_pool2d(graph, [2, 2], stride=2, padding="SAME", scope='Pool1') #### - + name = "Conv2a" + trainable = is_trainable(name, trainable_variables) graph = slim.conv2d(graph, 96, [1, 1], activation_fn=tf.nn.relu, stride=1, weights_initializer=initializer, - scope='Conv2a', + scope=name, + trainable=trainable, reuse=reuse) graph = maxout(graph, num_units=48, name='Maxout2a') + name = "Conv2" + trainable = is_trainable(name, trainable_variables) graph = slim.conv2d(graph, 192, [3, 3], activation_fn=tf.nn.relu, stride=1, weights_initializer=initializer, - scope='Conv2', + scope=name, + trainable=trainable, reuse=reuse) - end_points['conv2'] = graph - + end_points[name] = graph + graph = maxout(graph, num_units=96, name='Maxout2') @@ -57,23 +64,28 @@ def light_cnn9(inputs, seed=10, reuse=False): graph = slim.max_pool2d(graph, [2, 2], stride=2, padding="SAME", scope='Pool2') ##### - + name = "Conv3a" + trainable = is_trainable(name, trainable_variables) graph = slim.conv2d(graph, 192, [1, 1], activation_fn=tf.nn.relu, stride=1, weights_initializer=initializer, - scope='Conv3a', + scope=name, + trainable=trainable, reuse=reuse) graph = maxout(graph, num_units=96, name='Maxout3a') + name = "Conv3" + trainable = is_trainable(name, trainable_variables) graph = slim.conv2d(graph, 384, [3, 3], activation_fn=tf.nn.relu, stride=1, weights_initializer=initializer, - scope='Conv3', + scope=name, + trainable=trainable, reuse=reuse) - end_points['conv3'] = graph + end_points[name] = graph graph = maxout(graph, num_units=192, @@ -82,64 +94,76 @@ def light_cnn9(inputs, seed=10, reuse=False): graph = slim.max_pool2d(graph, [2, 2], stride=2, padding="SAME", scope='Pool3') ##### - + name = "Conv4a" + trainable = is_trainable(name, trainable_variables) graph = slim.conv2d(graph, 384, [1, 1], activation_fn=tf.nn.relu, stride=1, weights_initializer=initializer, - scope='Conv4a', + scope=name, + trainable=trainable, reuse=reuse) graph = maxout(graph, num_units=192, name='Maxout4a') + name = "Conv4" + trainable = is_trainable(name, trainable_variables) graph = slim.conv2d(graph, 256, [3, 3], activation_fn=tf.nn.relu, stride=1, weights_initializer=initializer, - scope='Conv4', + scope=name, + trainable=trainable, reuse=reuse) - end_points['conv4'] = graph + end_points[name] = graph graph = maxout(graph, num_units=128, name='Maxout4') ##### - + name = "Conv5a" + trainable = is_trainable(name, trainable_variables) graph = slim.conv2d(graph, 256, [1, 1], activation_fn=tf.nn.relu, stride=1, weights_initializer=initializer, - scope='Conv5a', + scope=name, + trainable=trainable, reuse=reuse) graph = maxout(graph, num_units=128, name='Maxout5a') + name = "Conv5" + trainable = is_trainable(name, trainable_variables) graph = slim.conv2d(graph, 256, [3, 3], activation_fn=tf.nn.relu, stride=1, weights_initializer=initializer, - scope='Conv5', + 
scope=name, + trainable=trainable, reuse=reuse) - end_points['conv5'] = graph + end_points[name] = graph graph = maxout(graph, num_units=128, - name='Maxout5') + name='Maxout5') graph = slim.max_pool2d(graph, [2, 2], stride=2, padding="SAME", scope='Pool4') graph = slim.flatten(graph, scope='flatten1') - end_points['flatten1'] = graph + end_points['flatten1'] = graph graph = slim.dropout(graph, keep_prob=0.5, scope='dropout1') + name = "fc1" + trainable = is_trainable(name, trainable_variables) prelogits = slim.fully_connected(graph, 512, - weights_initializer=initializer, - activation_fn=tf.nn.relu, - scope='fc1', - reuse=reuse) + weights_initializer=initializer, + activation_fn=tf.nn.relu, + scope=name, + trainable=trainable, + reuse=reuse) end_points['fc1'] = prelogits return prelogits, end_points - diff --git a/bob/learn/tensorflow/network/MLP.py b/bob/learn/tensorflow/network/MLP.py index 345dd5f..0f48e86 100755 --- a/bob/learn/tensorflow/network/MLP.py +++ b/bob/learn/tensorflow/network/MLP.py @@ -5,7 +5,7 @@ import tensorflow as tf -def mlp(inputs, output_shape, hidden_layers=[10], hidden_activation=tf.nn.tanh, output_activation=None, seed=10): +def mlp(inputs, output_shape, hidden_layers=[10], hidden_activation=tf.nn.tanh, output_activation=None, seed=10, **kwargs): """An MLP is a representation of a Multi-Layer Perceptron. This implementation is feed-forward and fully-connected. diff --git a/bob/learn/tensorflow/network/utils.py b/bob/learn/tensorflow/network/utils.py index a352e3f..f7abea4 100755 --- a/bob/learn/tensorflow/network/utils.py +++ b/bob/learn/tensorflow/network/utils.py @@ -14,3 +14,26 @@ def append_logits(graph, n_classes, reuse=False, l2_regularizer=0.001, stddev=weights_std), weights_regularizer=slim.l2_regularizer(l2_regularizer), scope='Logits', reuse=reuse) + + +def is_trainable(name, trainable_variables): + """ + Check if a variable is trainable or not + + Parameters + ---------- + + name: str + Layer name + + trainable_variables: list + List containing the variables or scopes to be trained. 
+ If None, the variable/scope is trained + """ + + # If None, we train by default + if trainable_variables is None: + return True + + # Here is my choice to shutdown the whole scope + return name in trainable_variables diff --git a/bob/learn/tensorflow/script/train_generic.py b/bob/learn/tensorflow/script/train_generic.py index b9c2c02..c5d89c1 100644 --- a/bob/learn/tensorflow/script/train_generic.py +++ b/bob/learn/tensorflow/script/train_generic.py @@ -55,6 +55,7 @@ def main(argv=None): defaults = docopt(docs, argv=[""]) args = docopt(docs, argv=argv, version=version) config_files = args[''] + config = read_config_file(config_files) # optional arguments diff --git a/bob/learn/tensorflow/test/test_dataset.py b/bob/learn/tensorflow/test/test_dataset.py index 31a0cea..2d1b2fe 100755 --- a/bob/learn/tensorflow/test/test_dataset.py +++ b/bob/learn/tensorflow/test/test_dataset.py @@ -14,7 +14,6 @@ batch_size = 2 validation_batch_size = 250 epochs = 1 - # Trainer logits filenames = [pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m301_01_p01_i0_0.png'), pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m301_01_p02_i0_0.png'), @@ -34,8 +33,8 @@ labels = [0, 0, 0, 0, 0, 0, def test_siamese_dataset(): - - data, label = siamese_batch(filenames, labels, data_shape, data_type, 2, per_image_normalization=False, output_shape=output_shape) + data, label = siamese_batch(filenames, labels, data_shape, data_type, 2, per_image_normalization=False, + output_shape=output_shape) with tf.Session() as session: d, l = session.run([data, label]) @@ -45,8 +44,8 @@ def test_siamese_dataset(): def test_triplet_dataset(): - - data = triplet_batch(filenames, labels, data_shape, data_type, 2, per_image_normalization=False, output_shape=output_shape) + data = triplet_batch(filenames, labels, data_shape, data_type, 2, per_image_normalization=False, + output_shape=output_shape) with tf.Session() as session: d = session.run([data])[0] assert len(d.keys()) == 3 diff --git a/bob/learn/tensorflow/test/test_estimator_onegraph.py b/bob/learn/tensorflow/test/test_estimator_onegraph.py index 7b5cfba..84f5131 100755 --- a/bob/learn/tensorflow/test/test_estimator_onegraph.py +++ b/bob/learn/tensorflow/test/test_estimator_onegraph.py @@ -7,13 +7,11 @@ import tensorflow as tf from bob.learn.tensorflow.network import dummy from bob.learn.tensorflow.estimators import Logits, LogitsCenterLoss -from bob.learn.tensorflow.dataset.tfrecords import shuffle_data_and_labels, batch_data_and_labels, shuffle_data_and_labels_image_augmentation +from bob.learn.tensorflow.dataset.tfrecords import shuffle_data_and_labels, batch_data_and_labels, \ + shuffle_data_and_labels_image_augmentation - -from bob.learn.tensorflow.dataset import append_image_augmentation from bob.learn.tensorflow.utils import load_mnist, create_mnist_tfrecord from bob.learn.tensorflow.utils.hooks import LoggerHookEstimator -from bob.learn.tensorflow.utils import reproducible from bob.learn.tensorflow.loss import mean_cross_entropy_loss import numpy @@ -21,9 +19,8 @@ import numpy import shutil import os - tfrecord_train = "./train_mnist.tfrecord" -tfrecord_validation = "./validation_mnist.tfrecord" +tfrecord_validation = "./validation_mnist.tfrecord" model_dir = "./temp" learning_rate = 0.1 @@ -40,12 +37,12 @@ def test_logitstrainer(): try: embedding_validation = False trainer = Logits(model_dir=model_dir, - architecture=dummy, - optimizer=tf.train.GradientDescentOptimizer(learning_rate), - n_classes=10, - loss_op=mean_cross_entropy_loss, - 
embedding_validation=embedding_validation, - validation_batch_size=validation_batch_size) + architecture=dummy, + optimizer=tf.train.GradientDescentOptimizer(learning_rate), + n_classes=10, + loss_op=mean_cross_entropy_loss, + embedding_validation=embedding_validation, + validation_batch_size=validation_batch_size) run_logitstrainer_mnist(trainer, augmentation=True) finally: try: @@ -53,19 +50,19 @@ def test_logitstrainer(): os.unlink(tfrecord_validation) shutil.rmtree(model_dir, ignore_errors=True) except Exception: - pass + pass def test_logitstrainer_embedding(): try: embedding_validation = True trainer = Logits(model_dir=model_dir, - architecture=dummy, - optimizer=tf.train.GradientDescentOptimizer(learning_rate), - n_classes=10, - loss_op=mean_cross_entropy_loss, - embedding_validation=embedding_validation, - validation_batch_size=validation_batch_size) + architecture=dummy, + optimizer=tf.train.GradientDescentOptimizer(learning_rate), + n_classes=10, + loss_op=mean_cross_entropy_loss, + embedding_validation=embedding_validation, + validation_batch_size=validation_batch_size) run_logitstrainer_mnist(trainer) finally: @@ -74,25 +71,24 @@ def test_logitstrainer_embedding(): os.unlink(tfrecord_validation) shutil.rmtree(model_dir, ignore_errors=True) except Exception: - pass + pass def test_logitstrainer_centerloss(): - try: embedding_validation = False run_config = tf.estimator.RunConfig() run_config = run_config.replace(save_checkpoints_steps=1000) trainer = LogitsCenterLoss( - model_dir=model_dir, - architecture=dummy, - optimizer=tf.train.GradientDescentOptimizer(learning_rate), - n_classes=10, - embedding_validation=embedding_validation, - validation_batch_size=validation_batch_size, - factor=0.01, - config=run_config) - + model_dir=model_dir, + architecture=dummy, + optimizer=tf.train.GradientDescentOptimizer(learning_rate), + n_classes=10, + embedding_validation=embedding_validation, + validation_batch_size=validation_batch_size, + factor=0.01, + config=run_config) + run_logitstrainer_mnist(trainer) # Checking if the centers were updated @@ -101,9 +97,9 @@ def test_logitstrainer_centerloss(): saver = tf.train.import_meta_graph(checkpoint_path + ".meta", clear_devices=True) saver.restore(sess, tf.train.latest_checkpoint(model_dir)) centers = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="center_loss/centers:0")[0] - assert numpy.sum(numpy.abs(centers.eval(sess))) > 0.0 + assert numpy.sum(numpy.abs(centers.eval(sess))) > 0.0 + - finally: try: os.unlink(tfrecord_train) @@ -117,33 +113,32 @@ def test_logitstrainer_centerloss_embedding(): try: embedding_validation = True trainer = LogitsCenterLoss( - model_dir=model_dir, - architecture=dummy, - optimizer=tf.train.GradientDescentOptimizer(learning_rate), - n_classes=10, - embedding_validation=embedding_validation, - validation_batch_size=validation_batch_size, - factor=0.01) + model_dir=model_dir, + architecture=dummy, + optimizer=tf.train.GradientDescentOptimizer(learning_rate), + n_classes=10, + embedding_validation=embedding_validation, + validation_batch_size=validation_batch_size, + factor=0.01) run_logitstrainer_mnist(trainer) - + # Checking if the centers were updated sess = tf.Session() checkpoint_path = tf.train.get_checkpoint_state(model_dir).model_checkpoint_path saver = tf.train.import_meta_graph(checkpoint_path + ".meta", clear_devices=True) saver.restore(sess, tf.train.latest_checkpoint(model_dir)) centers = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="center_loss/centers:0")[0] - assert 
numpy.sum(numpy.abs(centers.eval(sess))) > 0.0 + assert numpy.sum(numpy.abs(centers.eval(sess))) > 0.0 finally: try: os.unlink(tfrecord_train) os.unlink(tfrecord_validation) shutil.rmtree(model_dir, ignore_errors=True) except Exception: - pass + pass def run_logitstrainer_mnist(trainer, augmentation=False): - # Cleaning up tf.reset_default_graph() assert len(tf.global_variables()) == 0 @@ -155,25 +150,24 @@ def run_logitstrainer_mnist(trainer, augmentation=False): def input_fn(): if augmentation: - return shuffle_data_and_labels_image_augmentation(tfrecord_train, data_shape, data_type, batch_size, epochs=epochs) + return shuffle_data_and_labels_image_augmentation(tfrecord_train, data_shape, data_type, batch_size, + epochs=epochs) else: return shuffle_data_and_labels(tfrecord_train, data_shape, data_type, batch_size, epochs=epochs) - def input_fn_validation(): return batch_data_and_labels(tfrecord_validation, data_shape, data_type, validation_batch_size, epochs=1000) - + hooks = [LoggerHookEstimator(trainer, 16, 300), tf.train.SummarySaverHook(save_steps=1000, output_dir=model_dir, scaffold=tf.train.Scaffold(), - summary_writer=tf.summary.FileWriter(model_dir) )] + summary_writer=tf.summary.FileWriter(model_dir))] trainer.train(input_fn, steps=steps, hooks=hooks) - if not trainer.embedding_validation: acc = trainer.evaluate(input_fn_validation) assert acc['accuracy'] > 0.40 @@ -184,4 +178,3 @@ def run_logitstrainer_mnist(trainer, augmentation=False): # Cleaning up tf.reset_default_graph() assert len(tf.global_variables()) == 0 - diff --git a/bob/learn/tensorflow/test/test_estimator_siamese.py b/bob/learn/tensorflow/test/test_estimator_siamese.py index 37403bc..ce96b3a 100755 --- a/bob/learn/tensorflow/test/test_estimator_siamese.py +++ b/bob/learn/tensorflow/test/test_estimator_siamese.py @@ -12,18 +12,13 @@ from bob.learn.tensorflow.dataset.image import shuffle_data_and_labels_image_aug from bob.learn.tensorflow.loss import contrastive_loss, mean_cross_entropy_loss from bob.learn.tensorflow.utils.hooks import LoggerHookEstimator -from bob.learn.tensorflow.utils import reproducible from .test_estimator_transfer import dummy_adapted import pkg_resources - -import numpy import shutil -import os - tfrecord_train = "./train_mnist.tfrecord" -tfrecord_validation = "./validation_mnist.tfrecord" +tfrecord_validation = "./validation_mnist.tfrecord" model_dir = "./temp" model_dir_adapted = "./temp2" @@ -36,26 +31,24 @@ validation_batch_size = 2 epochs = 1 steps = 5000 - # Data filenames = [pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m301_01_p01_i0_0.png'), pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m301_01_p02_i0_0.png'), pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m301_01_p01_i0_0.png'), pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m301_01_p02_i0_0.png'), pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m301_01_p01_i0_0.png'), - pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m301_01_p02_i0_0.png'), + pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m301_01_p02_i0_0.png'), pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m301_01_p02_i0_0.png'), pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m301_01_p01_i0_0.png'), - pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m301_01_p02_i0_0.png'), + pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m301_01_p02_i0_0.png'), 
- - pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m304_01_p01_i0_0.png'), + pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m304_01_p01_i0_0.png'), pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m304_02_f12_i0_0.png'), - pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m304_01_p01_i0_0.png'), + pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m304_01_p01_i0_0.png'), pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m304_02_f12_i0_0.png'), - pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m304_01_p01_i0_0.png'), + pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m304_01_p01_i0_0.png'), pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m304_02_f12_i0_0.png'), - pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m304_01_p01_i0_0.png'), + pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m304_01_p01_i0_0.png'), pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m304_02_f12_i0_0.png'), pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m304_02_f12_i0_0.png'), ] @@ -67,31 +60,70 @@ def test_siamesetrainer(): # Trainer logits try: trainer = Siamese(model_dir=model_dir, + architecture=dummy, + optimizer=tf.train.GradientDescentOptimizer(learning_rate), + loss_op=contrastive_loss, + validation_batch_size=validation_batch_size) + run_siamesetrainer(trainer) + finally: + try: + shutil.rmtree(model_dir, ignore_errors=True) + # pass + except Exception: + pass + + +def test_siamesetrainer_transfer(): + def logits_input_fn(): + return single_batch(filenames, labels, data_shape, data_type, batch_size, epochs=epochs, + output_shape=output_shape) + + # Trainer logits first than siamese + try: + + extra_checkpoint = {"checkpoint_path": model_dir, + "scopes": dict({"Dummy/": "Dummy/"}), + "trainable_variables": [] + } + + # LOGISTS + logits_trainer = Logits(model_dir=model_dir, architecture=dummy, optimizer=tf.train.GradientDescentOptimizer(learning_rate), - loss_op=contrastive_loss, + n_classes=2, + loss_op=mean_cross_entropy_loss, + embedding_validation=False, validation_batch_size=validation_batch_size) + logits_trainer.train(logits_input_fn, steps=steps) + + # NOW THE FUCKING SIAMESE + trainer = Siamese(model_dir=model_dir_adapted, + architecture=dummy_adapted, + optimizer=tf.train.GradientDescentOptimizer(learning_rate), + loss_op=contrastive_loss, + validation_batch_size=validation_batch_size, + extra_checkpoint=extra_checkpoint) run_siamesetrainer(trainer) finally: try: shutil.rmtree(model_dir, ignore_errors=True) - #pass + shutil.rmtree(model_dir_adapted, ignore_errors=True) except Exception: - pass + pass -def test_siamesetrainer_transfer(): - +def test_siamesetrainer_transfer_extraparams(): def logits_input_fn(): - return single_batch(filenames, labels, data_shape, data_type, batch_size, epochs=epochs, output_shape=output_shape) + return single_batch(filenames, labels, data_shape, data_type, batch_size, epochs=epochs, + output_shape=output_shape) # Trainer logits first than siamese try: - extra_checkpoint = {"checkpoint_path":model_dir, + extra_checkpoint = {"checkpoint_path": model_dir, "scopes": dict({"Dummy/": "Dummy/"}), - "is_trainable": False - } + "trainable_variables": ["Dummy"] + } # LOGISTS logits_trainer = Logits(model_dir=model_dir, @@ -101,6 +133,7 @@ def test_siamesetrainer_transfer(): loss_op=mean_cross_entropy_loss, 
embedding_validation=False, validation_batch_size=validation_batch_size) + logits_trainer.train(logits_input_fn, steps=steps) # NOW THE FUCKING SIAMESE @@ -114,30 +147,31 @@ def test_siamesetrainer_transfer(): finally: try: shutil.rmtree(model_dir, ignore_errors=True) - shutil.rmtree(model_dir_adapted, ignore_errors=True) + shutil.rmtree(model_dir_adapted, ignore_errors=True) except Exception: - pass + pass def run_siamesetrainer(trainer): - # Cleaning up tf.reset_default_graph() assert len(tf.global_variables()) == 0 def input_fn(): - return siamese_batch(filenames, labels, data_shape, data_type, batch_size, epochs=epochs, output_shape=output_shape, + return siamese_batch(filenames, labels, data_shape, data_type, batch_size, epochs=epochs, + output_shape=output_shape, random_flip=True, random_brightness=True, random_contrast=True, random_saturation=True) def input_validation_fn(): - return single_batch(filenames, labels, data_shape, data_type, validation_batch_size, epochs=10, output_shape=output_shape) + return single_batch(filenames, labels, data_shape, data_type, validation_batch_size, epochs=10, + output_shape=output_shape) hooks = [LoggerHookEstimator(trainer, batch_size, 300), tf.train.SummarySaverHook(save_steps=1000, output_dir=model_dir, scaffold=tf.train.Scaffold(), - summary_writer=tf.summary.FileWriter(model_dir) )] + summary_writer=tf.summary.FileWriter(model_dir))] trainer.train(input_fn, steps=1, hooks=hooks) @@ -147,4 +181,3 @@ def run_siamesetrainer(trainer): # Cleaning up tf.reset_default_graph() assert len(tf.global_variables()) == 0 - diff --git a/bob/learn/tensorflow/test/test_estimator_transfer.py b/bob/learn/tensorflow/test/test_estimator_transfer.py index 4358f43..d08adbc 100755 --- a/bob/learn/tensorflow/test/test_estimator_transfer.py +++ b/bob/learn/tensorflow/test/test_estimator_transfer.py @@ -6,25 +6,15 @@ import tensorflow as tf from bob.learn.tensorflow.network import dummy from bob.learn.tensorflow.estimators import Logits, LogitsCenterLoss - -from bob.learn.tensorflow.dataset.tfrecords import shuffle_data_and_labels, batch_data_and_labels, shuffle_data_and_labels_image_augmentation - - -from bob.learn.tensorflow.dataset import append_image_augmentation -from bob.learn.tensorflow.utils import load_mnist, create_mnist_tfrecord -from bob.learn.tensorflow.utils.hooks import LoggerHookEstimator from bob.learn.tensorflow.utils import reproducible from bob.learn.tensorflow.loss import mean_cross_entropy_loss from .test_estimator_onegraph import run_logitstrainer_mnist -import numpy - import shutil import os - tfrecord_train = "./train_mnist.tfrecord" -tfrecord_validation = "./validation_mnist.tfrecord" +tfrecord_validation = "./validation_mnist.tfrecord" model_dir = "./temp" model_dir_adapted = "./temp2" @@ -37,68 +27,75 @@ epochs = 2 steps = 5000 -def dummy_adapted(inputs, reuse=False, mode = tf.estimator.ModeKeys.TRAIN, trainable_variables=True): +def dummy_adapted(inputs, reuse=False, mode=tf.estimator.ModeKeys.TRAIN, trainable_variables=None, **kwargs): """ Create all the necessary variables for this CNN - **Parameters** + Parameters + ---------- inputs: reuse: + + mode: + + trainable_variables: """ slim = tf.contrib.slim - graph, end_points = dummy(inputs, reuse=reuse, mode = mode, trainable_variables=trainable_variables) + graph, end_points = dummy(inputs, reuse=reuse, mode=mode, trainable_variables=trainable_variables) initializer = tf.contrib.layers.xavier_initializer() with tf.variable_scope('Adapted', reuse=reuse): + name = 'fc2' graph = 
slim.fully_connected(graph, 50, weights_initializer=initializer, activation_fn=tf.nn.relu, - scope='fc2') - end_points['fc2'] = graph + scope=name, + trainable=True) + end_points[name] = graph + name = 'fc3' graph = slim.fully_connected(graph, 25, weights_initializer=initializer, activation_fn=None, - scope='fc3') - end_points['fc3'] = graph - + scope=name, + trainable=True) + end_points[name] = graph return graph, end_points def test_logitstrainer(): # Trainer logits - try: + try: embedding_validation = False - trainer = Logits(model_dir=model_dir, - architecture=dummy, - optimizer=tf.train.GradientDescentOptimizer(learning_rate), - n_classes=10, - loss_op=mean_cross_entropy_loss, - embedding_validation=embedding_validation, - validation_batch_size=validation_batch_size) + architecture=dummy, + optimizer=tf.train.GradientDescentOptimizer(learning_rate), + n_classes=10, + loss_op=mean_cross_entropy_loss, + embedding_validation=embedding_validation, + validation_batch_size=validation_batch_size) run_logitstrainer_mnist(trainer, augmentation=True) del trainer ## Again - extra_checkpoint = {"checkpoint_path":"./temp", + extra_checkpoint = {"checkpoint_path": "./temp", "scopes": dict({"Dummy/": "Dummy/"}), - "is_trainable": False - } + "trainable_variables": [] + } trainer = Logits(model_dir=model_dir_adapted, - architecture=dummy_adapted, - optimizer=tf.train.GradientDescentOptimizer(learning_rate), - n_classes=10, - loss_op=mean_cross_entropy_loss, - embedding_validation=embedding_validation, - validation_batch_size=validation_batch_size, - extra_checkpoint=extra_checkpoint - ) - + architecture=dummy_adapted, + optimizer=tf.train.GradientDescentOptimizer(learning_rate), + n_classes=10, + loss_op=mean_cross_entropy_loss, + embedding_validation=embedding_validation, + validation_batch_size=validation_batch_size, + extra_checkpoint=extra_checkpoint + ) + run_logitstrainer_mnist(trainer, augmentation=True) finally: @@ -114,33 +111,33 @@ def test_logitstrainer(): def test_logitstrainer_center_loss(): # Trainer logits - try: + try: embedding_validation = False - + trainer = LogitsCenterLoss(model_dir=model_dir, - architecture=dummy, - optimizer=tf.train.GradientDescentOptimizer(learning_rate), - n_classes=10, - embedding_validation=embedding_validation, - validation_batch_size=validation_batch_size) + architecture=dummy, + optimizer=tf.train.GradientDescentOptimizer(learning_rate), + n_classes=10, + embedding_validation=embedding_validation, + validation_batch_size=validation_batch_size) run_logitstrainer_mnist(trainer, augmentation=True) del trainer ## Again - extra_checkpoint = {"checkpoint_path":"./temp", + extra_checkpoint = {"checkpoint_path": "./temp", "scopes": dict({"Dummy/": "Dummy/"}), - "is_trainable": False - } + "trainable_variables": ["Dummy"] + } trainer = LogitsCenterLoss(model_dir=model_dir_adapted, - architecture=dummy_adapted, - optimizer=tf.train.GradientDescentOptimizer(learning_rate), - n_classes=10, - embedding_validation=embedding_validation, - validation_batch_size=validation_batch_size, - extra_checkpoint=extra_checkpoint - ) - + architecture=dummy_adapted, + optimizer=tf.train.GradientDescentOptimizer(learning_rate), + n_classes=10, + embedding_validation=embedding_validation, + validation_batch_size=validation_batch_size, + extra_checkpoint=extra_checkpoint + ) + run_logitstrainer_mnist(trainer, augmentation=True) finally: @@ -151,4 +148,3 @@ def test_logitstrainer_center_loss(): shutil.rmtree(model_dir_adapted, ignore_errors=True) except Exception: pass - diff --git 
a/bob/learn/tensorflow/test/test_estimator_triplet.py b/bob/learn/tensorflow/test/test_estimator_triplet.py index 5c68ece..bbca851 100755 --- a/bob/learn/tensorflow/test/test_estimator_triplet.py +++ b/bob/learn/tensorflow/test/test_estimator_triplet.py @@ -15,13 +15,10 @@ from bob.learn.tensorflow.utils import reproducible import pkg_resources from .test_estimator_transfer import dummy_adapted -import numpy import shutil -import os - tfrecord_train = "./train_mnist.tfrecord" -tfrecord_validation = "./validation_mnist.tfrecord" +tfrecord_validation = "./validation_mnist.tfrecord" model_dir = "./temp" model_dir_adapted = "./temp2" @@ -34,26 +31,24 @@ validation_batch_size = 2 epochs = 1 steps = 5000 - # Data filenames = [pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m301_01_p01_i0_0.png'), pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m301_01_p02_i0_0.png'), pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m301_01_p01_i0_0.png'), pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m301_01_p02_i0_0.png'), pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m301_01_p01_i0_0.png'), - pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m301_01_p02_i0_0.png'), + pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m301_01_p02_i0_0.png'), pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m301_01_p02_i0_0.png'), pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m301_01_p01_i0_0.png'), - pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m301_01_p02_i0_0.png'), + pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m301_01_p02_i0_0.png'), - - pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m304_01_p01_i0_0.png'), + pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m304_01_p01_i0_0.png'), pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m304_02_f12_i0_0.png'), - pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m304_01_p01_i0_0.png'), + pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m304_01_p01_i0_0.png'), pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m304_02_f12_i0_0.png'), - pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m304_01_p01_i0_0.png'), + pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m304_01_p01_i0_0.png'), pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m304_02_f12_i0_0.png'), - pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m304_01_p01_i0_0.png'), + pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m304_01_p01_i0_0.png'), pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m304_02_f12_i0_0.png'), pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m304_02_f12_i0_0.png'), ] @@ -73,23 +68,23 @@ def test_triplet_estimator(): finally: try: shutil.rmtree(model_dir, ignore_errors=True) - #pass + # pass except Exception: - pass + pass def test_triplettrainer_transfer(): - def logits_input_fn(): - return single_batch(filenames, labels, data_shape, data_type, batch_size, epochs=epochs, output_shape=output_shape) + return single_batch(filenames, labels, data_shape, data_type, batch_size, epochs=epochs, + output_shape=output_shape) # Trainer logits first than siamese try: - extra_checkpoint = {"checkpoint_path":model_dir, + 
extra_checkpoint = {"checkpoint_path": model_dir, "scopes": dict({"Dummy/": "Dummy/"}), - "is_trainable": False - } + "trainable_variables": [] + } # LOGISTS logits_trainer = Logits(model_dir=model_dir, @@ -118,24 +113,25 @@ def test_triplettrainer_transfer(): def run_triplet_estimator(trainer): - # Cleaning up tf.reset_default_graph() assert len(tf.global_variables()) == 0 def input_fn(): - return triplet_batch(filenames, labels, data_shape, data_type, batch_size, epochs=epochs, output_shape=output_shape, + return triplet_batch(filenames, labels, data_shape, data_type, batch_size, epochs=epochs, + output_shape=output_shape, random_flip=True, random_brightness=True, random_contrast=True, random_saturation=True) def input_validation_fn(): - return single_batch(filenames, labels, data_shape, data_type, validation_batch_size, epochs=10, output_shape=output_shape) + return single_batch(filenames, labels, data_shape, data_type, validation_batch_size, epochs=10, + output_shape=output_shape) hooks = [LoggerHookEstimator(trainer, batch_size, 300), tf.train.SummarySaverHook(save_steps=1000, output_dir=model_dir, scaffold=tf.train.Scaffold(), - summary_writer=tf.summary.FileWriter(model_dir) )] + summary_writer=tf.summary.FileWriter(model_dir))] trainer.train(input_fn, steps=steps, hooks=hooks) diff --git a/bob/learn/tensorflow/test/test_image_dataset.py b/bob/learn/tensorflow/test/test_image_dataset.py index e933dec..d3a63c5 100755 --- a/bob/learn/tensorflow/test/test_image_dataset.py +++ b/bob/learn/tensorflow/test/test_image_dataset.py @@ -10,14 +10,9 @@ from bob.learn.tensorflow.estimators import Logits, LogitsCenterLoss from bob.learn.tensorflow.dataset.image import shuffle_data_and_labels_image_augmentation import pkg_resources -from bob.learn.tensorflow.dataset import append_image_augmentation -from bob.learn.tensorflow.utils import load_mnist, create_mnist_tfrecord from bob.learn.tensorflow.utils.hooks import LoggerHookEstimator -from bob.learn.tensorflow.utils import reproducible from bob.learn.tensorflow.loss import mean_cross_entropy_loss -import numpy - import shutil import os @@ -33,17 +28,16 @@ steps = 5000 def test_logitstrainer_images(): - # Trainer logits try: embedding_validation = False trainer = Logits(model_dir=model_dir, - architecture=dummy, - optimizer=tf.train.GradientDescentOptimizer(learning_rate), - n_classes=10, - loss_op=mean_cross_entropy_loss, - embedding_validation=embedding_validation, - validation_batch_size=validation_batch_size) + architecture=dummy, + optimizer=tf.train.GradientDescentOptimizer(learning_rate), + n_classes=10, + loss_op=mean_cross_entropy_loss, + embedding_validation=embedding_validation, + validation_batch_size=validation_batch_size) run_logitstrainer_images(trainer) finally: try: @@ -51,35 +45,35 @@ def test_logitstrainer_images(): os.unlink(tfrecord_validation) shutil.rmtree(model_dir, ignore_errors=True) except Exception: - pass - + pass + def run_logitstrainer_images(trainer): # Cleaning up tf.reset_default_graph() assert len(tf.global_variables()) == 0 - + filenames = [pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m301_01_p01_i0_0.png'), - pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m301_01_p02_i0_0.png'), - pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m304_01_p01_i0_0.png'), - pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m304_02_f12_i0_0.png')] + pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m301_01_p02_i0_0.png'), + 
pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m304_01_p01_i0_0.png'), + pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m304_02_f12_i0_0.png')] labels = [0, 0, 1, 1] - + def input_fn(): - return shuffle_data_and_labels_image_augmentation(filenames,labels, data_shape, data_type, batch_size, epochs=epochs) - + return shuffle_data_and_labels_image_augmentation(filenames, labels, data_shape, data_type, batch_size, + epochs=epochs) def input_fn_validation(): return shuffle_data_and_labels_image_augmentation(filenames, labels, data_shape, data_type, validation_batch_size, epochs=1000) - + hooks = [LoggerHookEstimator(trainer, 16, 300), tf.train.SummarySaverHook(save_steps=1000, output_dir=model_dir, scaffold=tf.train.Scaffold(), - summary_writer=tf.summary.FileWriter(model_dir) )] + summary_writer=tf.summary.FileWriter(model_dir))] trainer.train(input_fn, steps=steps, hooks=hooks) @@ -93,4 +87,3 @@ def run_logitstrainer_images(trainer): # Cleaning up tf.reset_default_graph() assert len(tf.global_variables()) == 0 - -- GitLab
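For readers skimming the tests above, the transfer-learning flow they exercise can be condensed as follows. This is an illustrative sketch, not part of the patch: dummy and dummy_adapted are the tiny test fixtures, the two input functions are left to the caller (the tests build them with single_batch and siamese_batch), the learning rate and step counts for the Siamese stage are placeholders, and the freezing semantics described in the comments (an empty trainable_variables list freezes everything restored from the checkpoint scopes, while naming a scope such as "Dummy" keeps it trainable) are inferred from the two transfer tests rather than guaranteed by this excerpt.

import tensorflow as tf

from bob.learn.tensorflow.network import dummy
from bob.learn.tensorflow.estimators import Logits, Siamese
from bob.learn.tensorflow.loss import mean_cross_entropy_loss, contrastive_loss
# dummy_adapted is the test fixture defined in test_estimator_transfer.py above
from bob.learn.tensorflow.test.test_estimator_transfer import dummy_adapted


def transfer_train(logits_input_fn, siamese_input_fn,
                   model_dir="./temp", model_dir_adapted="./temp2",
                   learning_rate=0.001):  # placeholder value, not the tests' constant
    """Pre-train a Logits estimator, then fine-tune a Siamese estimator on top of it."""
    # 1) Pre-train a small classifier; its variables live under the "Dummy/" scope.
    logits_trainer = Logits(model_dir=model_dir,
                            architecture=dummy,
                            optimizer=tf.train.GradientDescentOptimizer(learning_rate),
                            n_classes=2,
                            loss_op=mean_cross_entropy_loss,
                            embedding_validation=False,
                            validation_batch_size=2)
    logits_trainer.train(logits_input_fn, steps=5000)

    # 2) Describe the transfer: where the checkpoint lives, how source scopes map to
    #    target scopes, and which restored variables stay trainable afterwards
    #    ([] freezes them all; ["Dummy"] would keep the restored scope trainable too).
    extra_checkpoint = {"checkpoint_path": model_dir,
                        "scopes": dict({"Dummy/": "Dummy/"}),
                        "trainable_variables": []}

    # 3) Fine-tune a Siamese network that extends the pre-trained body; with the
    #    settings above only the new "Adapted/" layers of dummy_adapted are updated.
    siamese_trainer = Siamese(model_dir=model_dir_adapted,
                              architecture=dummy_adapted,
                              optimizer=tf.train.GradientDescentOptimizer(learning_rate),
                              loss_op=contrastive_loss,
                              validation_batch_size=2,
                              extra_checkpoint=extra_checkpoint)
    siamese_trainer.train(siamese_input_fn, steps=1000)  # the smoke test above uses steps=1
    return siamese_trainer

test_siamesetrainer_transfer and test_siamesetrainer_transfer_extraparams above follow exactly this pattern, differing only in whether the restored "Dummy" scope is listed in trainable_variables.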