Commit a102df67 authored by Tiago de Freitas Pereira

Merge branch 'predict' into 'master'

Changes to the biogenerator

See merge request !33
parents 5fb18ca1 4d08ad93
......@@ -73,7 +73,8 @@ def append_image_augmentation(image, gray_scale=False,
if output_shape is not None:
assert len(output_shape) == 2
image = tf.image.resize_image_with_crop_or_pad(image, output_shape[0], output_shape[1])
image = tf.image.resize_image_with_crop_or_pad(
image, output_shape[0], output_shape[1])
if random_flip:
image = tf.image.random_flip_left_right(image)
......@@ -136,15 +137,18 @@ def triplets_random_generator(input_data, input_labels):
input_labels = numpy.array(input_labels)
total_samples = input_data.shape[0]
indexes_per_labels = arrange_indexes_by_label(input_labels, possible_labels)
indexes_per_labels = arrange_indexes_by_label(
input_labels, possible_labels)
# searching for random triplets
offset_class = 0
for i in range(total_samples):
anchor_sample = input_data[indexes_per_labels[possible_labels[offset_class]][numpy.random.randint(len(indexes_per_labels[possible_labels[offset_class]]))], ...]
anchor_sample = input_data[indexes_per_labels[possible_labels[offset_class]][numpy.random.randint(
len(indexes_per_labels[possible_labels[offset_class]]))], ...]
positive_sample = input_data[indexes_per_labels[possible_labels[offset_class]][numpy.random.randint(len(indexes_per_labels[possible_labels[offset_class]]))], ...]
positive_sample = input_data[indexes_per_labels[possible_labels[offset_class]][numpy.random.randint(
len(indexes_per_labels[possible_labels[offset_class]]))], ...]
# Changing the class
offset_class += 1
......@@ -152,10 +156,11 @@ def triplets_random_generator(input_data, input_labels):
if offset_class == len(possible_labels):
offset_class = 0
negative_sample = input_data[indexes_per_labels[possible_labels[offset_class]][numpy.random.randint(len(indexes_per_labels[possible_labels[offset_class]]))], ...]
negative_sample = input_data[indexes_per_labels[possible_labels[offset_class]][numpy.random.randint(
len(indexes_per_labels[possible_labels[offset_class]]))], ...]
append(str(anchor_sample), str(positive_sample), str(negative_sample))
#yield anchor, positive, negative
# yield anchor, positive, negative
return anchor, positive, negative
......@@ -191,13 +196,16 @@ def siamease_pairs_generator(input_data, input_labels):
# Filtering the samples by label and shuffling all the indexes
#indexes_per_labels = dict()
#for l in possible_labels:
# for l in possible_labels:
# indexes_per_labels[l] = numpy.where(input_labels == l)[0]
# numpy.random.shuffle(indexes_per_labels[l])
indexes_per_labels = arrange_indexes_by_label(input_labels, possible_labels)
indexes_per_labels = arrange_indexes_by_label(
input_labels, possible_labels)
left_possible_indexes = numpy.random.choice(possible_labels, total_samples, replace=True)
right_possible_indexes = numpy.random.choice(possible_labels, total_samples, replace=True)
left_possible_indexes = numpy.random.choice(
possible_labels, total_samples, replace=True)
right_possible_indexes = numpy.random.choice(
possible_labels, total_samples, replace=True)
genuine = True
for i in range(total_samples):
......@@ -207,10 +215,12 @@ def siamease_pairs_generator(input_data, input_labels):
class_index = left_possible_indexes[i]
# Now selecting the samples for the pair
left = input_data[indexes_per_labels[class_index][numpy.random.randint(len(indexes_per_labels[class_index]))]]
right = input_data[indexes_per_labels[class_index][numpy.random.randint(len(indexes_per_labels[class_index]))]]
left = input_data[indexes_per_labels[class_index][numpy.random.randint(
len(indexes_per_labels[class_index]))]]
right = input_data[indexes_per_labels[class_index][numpy.random.randint(
len(indexes_per_labels[class_index]))]]
append(left, right, 0)
#yield left, right, 0
# yield left, right, 0
else:
# Selecting the 2 classes
class_index = list()
......@@ -219,7 +229,7 @@ def siamease_pairs_generator(input_data, input_labels):
# Finding the right pair
j = i
# TODO: Lame solution. Fix this
while j < total_samples: # Here is a unidirectional search for the negative pair
while j < total_samples:  # Here is a unidirectional search for the negative pair
if left_possible_indexes[i] != right_possible_indexes[j]:
class_index.append(right_possible_indexes[j])
break
......@@ -227,11 +237,12 @@ def siamease_pairs_generator(input_data, input_labels):
if j < total_samples:
# Now selecting the samples for the pair
left = input_data[indexes_per_labels[class_index[0]][numpy.random.randint(len(indexes_per_labels[class_index[0]]))]]
right = input_data[indexes_per_labels[class_index[1]][numpy.random.randint(len(indexes_per_labels[class_index[1]]))]]
left = input_data[indexes_per_labels[class_index[0]][numpy.random.randint(
len(indexes_per_labels[class_index[0]]))]]
right = input_data[indexes_per_labels[class_index[1]][numpy.random.randint(
len(indexes_per_labels[class_index[1]]))]]
append(left, right, 1)
genuine = not genuine
return left_data, right_data, labels
......@@ -296,3 +307,30 @@ def tf_repeat(tensor, repeats):
tiled_tensor = tf.tile(expanded_tensor, multiples=multiples)
repeated_tensor = tf.reshape(tiled_tensor, tf.shape(tensor) * repeats)
return repeated_tensor
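As a hedged illustration of what the tiling and reshape above produce (not part of the diff, and assuming ``tf_repeat`` from this module is in scope):

import tensorflow as tf

t = tf.constant([1, 2])
r = tf_repeat(t, [3])  # repeat each element 3 times along axis 0
with tf.Session() as sess:
    print(sess.run(r))  # -> [1 1 1 2 2 2]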
def all_patches(image, label, key, size):
"""Extracts all patches of an image
Parameters
----------
image
The image should be channels_last format and already batched.
label
The label for the image
key
The key for the image
size : (int, int)
The height and width of the blocks.
Returns
-------
(blocks, label, key)
The non-overlapping blocks of ``size`` extracted from ``image``; the
label and key are repeated once per block.
"""
blocks, n_blocks = blocks_tensorflow(image, size)
# duplicate label and key as n_blocks
label = tf_repeat(label, [n_blocks])
key = tf_repeat(key, [n_blocks])
return blocks, label, key
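A hedged usage sketch (not from the diff): ``all_patches`` can be mapped over a ``tf.data`` pipeline whose elements are batched ``(image, label, key)`` tuples; the ``dataset`` variable and the patch size below are illustrative assumptions.

import functools
import tensorflow as tf

size = (28, 28)  # illustrative patch size
dataset = dataset.map(functools.partial(all_patches, size=size))
# every element now carries the non-overlapping blocks plus the label and
# key repeated once per block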
import six
import tensorflow as tf
from bob.bio.base import read_original_data
import logging
logger = logging.getLogger(__name__)
def bio_generator(database, biofiles, load_data=None, biofile_to_label=None,
multiple_samples=False, repeat=False):
"""Returns a generator and its output types and shapes based on
bob.bio.base databases.
Parameters
class BioGenerator(object):
"""A generator class which wraps bob.bio.base databases so that they can
be used with tf.data.Dataset.from_generator
Attributes
----------
database : :any:`bob.bio.base.database.BioDatabase`
The database that you want to use.
biofile_to_label : :obj:`object`, optional
A callable with the signature of ``label = biofile_to_label(biofile)``.
By default -1 is returned as label.
biofiles : [:any:`bob.bio.base.database.BioFile`]
The list of the bio files.
database : :any:`bob.bio.base.database.BioDatabase`
The database that you want to use.
epoch : int
The number of epochs that have been passed so far.
keys : [str]
The keys of samples obtained by calling ``biofile.make_path("", "")``
labels : [int]
The labels obtained by calling ``label = biofile_to_label(biofile)``
load_data : :obj:`object`, optional
A callable with the signature of
``data = load_data(database, biofile)``.
:any:`bob.bio.base.read_original_data` is used by default.
biofile_to_label : :obj:`object`, optional
A callable with the signature of ``label = biofile_to_label(biofile)``.
By default -1 is returned as label.
:any:`bob.bio.base.read_original_data` is wrapped to be used by
default.
multiple_samples : bool, optional
If true, it assumes that the bio database's samples actually contain
multiple samples. This is useful for when you want to treat video
databases as image databases.
repeat : bool, optional
If True, the samples are repeated forever.
Returns
-------
generator : object
A generator function that when called will return the samples. The
samples will be like ``(data, label, key)``.
multiple samples. This is useful for when you want to for example treat
video databases as image databases.
output_types : (object, object, object)
The types of the returned samples.
output_shapes : (tf.TensorShape, tf.TensorShape, tf.TensorShape)
The shapes of the returned samples.
"""
if load_data is None:
def load_data(database, biofile):
data = read_original_data(
biofile,
database.original_directory,
database.original_extension)
return data
if biofile_to_label is None:
def biofile_to_label(biofile):
return -1
labels = (biofile_to_label(f) for f in biofiles)
keys = (str(f.make_path("", "")) for f in biofiles)
def generator():
while True:
for f, label, key in six.moves.zip(biofiles, labels, keys):
data = load_data(database, f)
# labels
if multiple_samples:
for d in data:
yield (d, label, key)
else:
yield (data, label, key)
if not repeat:
break
def __init__(self, database, biofiles, load_data=None,
biofile_to_label=None, multiple_samples=False):
if load_data is None:
def load_data(database, biofile):
data = read_original_data(
biofile,
database.original_directory,
database.original_extension)
return data
if biofile_to_label is None:
def biofile_to_label(biofile):
return -1
self.database = database
self.biofiles = list(biofiles)
self.load_data = load_data
self.biofile_to_label = biofile_to_label
self.multiple_samples = multiple_samples
self.epoch = 0
# load one data to get its type and shape
data = load_data(database, biofiles[0])
if multiple_samples:
try:
data = data[0]
except TypeError:
# if the data is a generator
data = six.next(data)
data = tf.convert_to_tensor(data)
self._output_types = (data.dtype, tf.int64, tf.string)
self._output_shapes = (
data.shape, tf.TensorShape([]), tf.TensorShape([]))
logger.info("Initializing a dataset with %d files and %s types "
"and %s shapes", len(self.biofiles), self.output_types,
self.output_shapes)
@property
def labels(self):
for f in self.biofiles:
yield self.biofile_to_label(f)
@property
def keys(self):
for f in self.biofiles:
yield str(f.make_path("", "")).encode('utf-8')
@property
def output_types(self):
return self._output_types
@property
def output_shapes(self):
return self._output_shapes
# load one data to get its type and shape
data = load_data(database, biofiles[0])
if multiple_samples:
try:
data = data[0]
except TypeError:
# if the data is a generator
data = six.next(data)
data = tf.convert_to_tensor(data)
output_types = (data.dtype, tf.int64, tf.string)
output_shapes = (data.shape, tf.TensorShape([]), tf.TensorShape([]))
def __call__(self):
"""A generator function that when called will return the samples.
return (generator, output_types, output_shapes)
Yields
------
(data, label, key) : tuple
A tuple containing the data, label, and the key.
"""
for f, label, key in six.moves.zip(
self.biofiles, self.labels, self.keys):
data = self.load_data(self.database, f)
# labels
if self.multiple_samples:
for d in data:
yield (d, label, key)
else:
yield (data, label, key)
self.epoch += 1
logger.info("Elapsed %d epochs", self.epoch)
......@@ -3,14 +3,6 @@
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
import tensorflow as tf
import threading
import os
import bob.io.base
import bob.core
from tensorflow.core.framework import summary_pb2
import time
#logger = bob.core.log.setup("bob.learn.tensorflow")
from bob.learn.tensorflow.network.utils import append_logits
from tensorflow.python.estimator import estimator
from bob.learn.tensorflow.utils import predict_using_tensors
......@@ -28,102 +20,88 @@ class Logits(estimator.Estimator):
The **architecture** function should follow the following pattern:
def my_beautiful_function(placeholder):
def my_beautiful_architecture(placeholder, **kwargs):
end_points = dict()
graph = convXX(placeholder)
end_points['conv'] = graph
....
return graph, end_points
end_points = dict()
graph = convXX(placeholder)
end_points['conv'] = graph
....
return graph, end_points
The **loss** function should follow the following pattern:
def my_beautiful_loss(logits, labels):
def my_beautiful_loss(logits, labels, **kwargs):
return loss_set_of_ops(logits, labels)
Variables, scopes... from other models can be loaded by the model_fn.
For that, please wrap the path of the OTHER checkpoint and the list
of variables in a dictionary with the key "load_variable_from_checkpoint" and provide them to the keyword `params`:
{"load_variable_from_checkpoint": {"checkpoint_path": "mypath",
"scopes": {"my_scope/": "my_scope/"}}}
Parameters
----------
**Parameters**
architecture:
Pointer to a function that builds the graph.
optimizer:
One of the tensorflow solvers (https://www.tensorflow.org/api_guides/python/train)
One of the tensorflow solvers
(https://www.tensorflow.org/api_guides/python/train)
- tf.train.GradientDescentOptimizer
- tf.train.AdagradOptimizer
- ....
config:
n_classes:
Number of classes of your problem. The logits will be appended in this class
Number of classes of your problem. The logits will be appended in this
class
loss_op:
Pointer to a function that computes the loss.
embedding_validation:
Run the validation using embeddings? [default: False]
model_dir:
Model path
validation_batch_size:
Size of the batch for validation. This value is used when the
validation with embeddings is used. This is a hack.
params:
Extra params for the model function (please see https://www.tensorflow.org/extend/estimators for more info)
Extra params for the model function (please see
https://www.tensorflow.org/extend/estimators for more info)
extra_checkpoint: dict()
extra_checkpoint: dict
In case you want to use other model to initialize some variables.
This argument should be in the following format
extra_checkpoint = {"checkpoint_path": <YOUR_CHECKPOINT>,
"scopes": dict({"<SOURCE_SCOPE>/": "<TARGET_SCOPE>/"}),
"is_trainable": <IF_THOSE_LOADED_VARIABLES_ARE_TRAINABLE>
}
extra_checkpoint = {
"checkpoint_path": <YOUR_CHECKPOINT>,
"scopes": dict({"<SOURCE_SCOPE>/": "<TARGET_SCOPE>/"}),
"is_trainable": <IF_THOSE_LOADED_VARIABLES_ARE_TRAINABLE>
}
"""
def __init__(self,
architecture=None,
optimizer=None,
architecture,
optimizer,
loss_op,
n_classes,
config=None,
n_classes=0,
loss_op=None,
embedding_validation=False,
model_dir="",
validation_batch_size=None,
params=None,
extra_checkpoint=None
):
):
self.architecture = architecture
self.optimizer=optimizer
self.n_classes=n_classes
self.loss_op=loss_op
self.optimizer = optimizer
self.n_classes = n_classes
self.loss_op = loss_op
self.loss = None
self.embedding_validation = embedding_validation
self.extra_checkpoint = extra_checkpoint
if self.architecture is None:
raise ValueError("Please specify a function to build the architecture !!")
if self.optimizer is None:
raise ValueError("Please specify a optimizer (https://www.tensorflow.org/api_guides/python/train) !!")
if self.loss_op is None:
raise ValueError("Please specify a function to build the loss !!")
if self.n_classes <= 0:
raise ValueError("Number of classes must be greated than 0")
def _model_fn(features, labels, mode, params, config):
check_features(features)
......@@ -164,28 +142,36 @@ class Logits(estimator.Estimator):
# Compute the embeddings
embeddings = tf.nn.l2_normalize(prelogits, 1)
predictions = {
"embeddings": embeddings
"embeddings": embeddings,
"key": key,
}
else:
probabilities = tf.nn.softmax(logits, name="softmax_tensor")
predictions = {
# Generate predictions (for PREDICT and EVAL mode)
"classes": tf.argmax(input=logits, axis=1),
# Add `softmax_tensor` to the graph. It is used for PREDICT and by the
# `logging_hook`.
"probabilities": tf.nn.softmax(logits, name="softmax_tensor")
# Add `softmax_tensor` to the graph. It is used for PREDICT
# and by the `logging_hook`.
"probabilities": probabilities,
"key": key,
}
if mode == tf.estimator.ModeKeys.PREDICT:
return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)
return tf.estimator.EstimatorSpec(mode=mode,
predictions=predictions)
# IF Validation
self.loss = self.loss_op(logits, labels)
if self.embedding_validation:
predictions_op = predict_using_tensors(predictions["embeddings"], labels, num=validation_batch_size)
eval_metric_ops = {"accuracy": tf.metrics.accuracy(labels=labels, predictions=predictions_op)}
return tf.estimator.EstimatorSpec(mode=mode, loss=self.loss, eval_metric_ops=eval_metric_ops)
predictions_op = predict_using_tensors(
predictions["embeddings"], labels,
num=validation_batch_size)
eval_metric_ops = {"accuracy": tf.metrics.accuracy(
labels=labels, predictions=predictions_op)}
return tf.estimator.EstimatorSpec(
mode=mode, loss=self.loss, eval_metric_ops=eval_metric_ops)
else:
# Add evaluation metrics (for EVAL mode)
eval_metric_ops = {
......@@ -219,7 +205,8 @@ class LogitsCenterLoss(estimator.Estimator):
Pointer to a function that builds the graph.
optimizer:
One of the tensorflow solvers (https://www.tensorflow.org/api_guides/python/train)
One of the tensorflow solvers
(https://www.tensorflow.org/api_guides/python/train)
- tf.train.GradientDescentOptimizer
- tf.train.AdagradOptimizer
- ....
......@@ -227,7 +214,8 @@ class LogitsCenterLoss(estimator.Estimator):
config:
n_classes:
Number of classes of your problem. The logits will be appended in this class
Number of classes of your problem. The logits will be appended in this
class
loss_op:
Pointer to a function that computes the loss.
......@@ -241,10 +229,11 @@ class LogitsCenterLoss(estimator.Estimator):
validation_batch_size:
Size of the batch for validation. This value is used when the
validation with embeddings is used. This is a hack.
params:
Extra params for the model function (please see https://www.tensorflow.org/extend/estimators for more info)
Extra params for the model function (please see
https://www.tensorflow.org/extend/estimators for more info)
"""
def __init__(self,
......@@ -258,8 +247,8 @@ class LogitsCenterLoss(estimator.Estimator):
factor=0.01,
validation_batch_size=None,
params=None,
extra_checkpoint=None,
):
extra_checkpoint=None,
):
self.architecture = architecture
self.optimizer = optimizer
......@@ -271,10 +260,13 @@ class LogitsCenterLoss(estimator.Estimator):
self.extra_checkpoint = extra_checkpoint
if self.architecture is None:
raise ValueError("Please specify a function to build the architecture !!")
raise ValueError(
"Please specify a function to build the architecture !!")
if self.optimizer is None:
raise ValueError("Please specify a optimizer (https://www.tensorflow.org/api_guides/python/train) !!")
raise ValueError(
"Please specify a optimizer (https://www.tensorflow.org/"
"api_guides/python/train) !!")
if self.n_classes <= 0:
raise ValueError("Number of classes must be greated than 0")
......@@ -323,7 +315,8 @@ class LogitsCenterLoss(estimator.Estimator):
# Compute the embeddings
embeddings = tf.nn.l2_normalize(prelogits, 1)
predictions = {
"embeddings": embeddings
"embeddings": embeddings,
"key": key,
}
else:
predictions = {
......@@ -331,7 +324,8 @@ class LogitsCenterLoss(estimator.Estimator):
"classes": tf.argmax(input=logits, axis=1),
# Add `softmax_tensor` to the graph. It is used for PREDICT and by the
# `logging_hook`.
"probabilities": tf.nn.softmax(logits, name="softmax_tensor")
"probabilities": tf.nn.softmax(logits, name="softmax_tensor"),
"key": key,
}
if mode == tf.estimator.ModeKeys.PREDICT:
......@@ -343,8 +337,10 @@ class LogitsCenterLoss(estimator.Estimator):
self.loss = loss_dict['loss']
if self.embedding_validation:
predictions_op = predict_using_tensors(predictions["embeddings"], labels, num=validation_batch_size)
eval_metric_ops = {"accuracy": tf.metrics.accuracy(labels=labels, predictions=predictions_op)}
predictions_op = predict_using_tensors(
predictions["embeddings"], labels, num=validation_batch_size)
eval_metric_ops = {"accuracy": tf.metrics.accuracy(
labels=labels, predictions=predictions_op)}
return tf.estimator.EstimatorSpec(mode=mode, loss=self.loss, eval_metric_ops=eval_metric_ops)
else:
......
import tensorflow as tf
def architecture(input_layer, mode=tf.estimator.ModeKeys.TRAIN,
kernerl_size=(3, 3), n_classes=2,
data_format='channels_last'):
def base_architecture(input_layer, mode, kernerl_size, data_format, **kwargs):
# Keep track of all the endpoints
endpoints = {}
# Convolutional Layer #1
# Computes 32 features using a kernerl_size filter with ReLU activation.
# Computes 32 features using a kernerl_size filter with ReLU
# activation.
# Padding is added to preserve width and height.
conv1 = tf.layers.conv2d(
inputs=input_layer,
......@@ -22,8 +20,8 @@ def architecture(input_layer, mode=tf.estimator.ModeKeys.TRAIN,
# Pooling Layer #1
# First max pooling layer with a 2x2 filter and stride of 2
pool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2,
data_format=data_format)
pool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2],
strides=2, data_format=data_format)
endpoints['pool1'] = pool1
# Convolutional Layer #2
......@@ -40,8 +38,8 @@ def architecture(input_layer, mode=tf.estimator.ModeKeys.TRAIN,
# Pooling Layer #2
# Second max pooling layer with a 2x2 filter and stride of 2
pool2 = tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2], strides=2,
data_format=data_format)
pool2 = tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2],
strides=2, data_format=data_format)
endpoints['pool2'] = pool2
# Flatten tensor into a batch of vectors
......@@ -57,14 +55,26 @@ def architecture(input_layer, mode=tf.estimator.ModeKeys.TRAIN,
# Add dropout operation; 0.6 probability that element will be kept
dropout = tf.layers.dropout(
inputs=dense, rate=0.4, training=mode == tf.estimator.ModeKeys.TRAIN)
inputs=dense, rate=0.4,
training=mode == tf.estimator.ModeKeys.TRAIN)
endpoints['dropout'] = dropout
# Logits layer
# Input Tensor Shape: [batch_size, 1024]
# Output Tensor Shape: [batch_size, 2]
logits = tf.layers.dense(inputs=dropout, units=n_classes)
endpoints['logits'] = logits
return dropout, endpoints
def architecture(input_layer, mode=tf.estimator.ModeKeys.TRAIN,
kernerl_size=(3, 3), n_classes=2,
data_format='channels_last', reuse=False, **kwargs):
with tf.variable_scope('SimpleCNN', reuse=reuse):
dropout, endpoints = base_architecture(
input_layer, mode, kernerl_size, data_format)
# Logits layer
# Input Tensor Shape: [batch_size, 1024]
# Output Tensor Shape: [batch_size, n_classes]
logits = tf.layers.dense(inputs=dropout, units=n_classes)
endpoints['logits'] = logits
return logits, endpoints
......@@ -72,7 +82,7 @@ def architecture(input_layer, mode=tf.estimator.ModeKeys.TRAIN,
def model_fn(features, labels, mode, params=None, config=None):
"""Model function for CNN."""
data = features['data']
keys = features['key']
key = features['key']
params = params or {}
learning_rate = params.get('learning_rate', 1e-5)
......@@ -92,7 +102,7 @@ def model_fn(features, labels, mode, params=None, config=None):
# Add `softmax_tensor` to the graph. It is used for PREDICT and by the
# `logging_hook`.
"probabilities": tf.nn.softmax(logits, name="softmax_tensor"),
'keys': keys,
'key': key,
}
if mode == tf.estimator.ModeKeys.PREDICT:
return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)
......@@ -117,7 +127,6 @@ def model_fn(features, labels, mode, params=None, config=None):
else:
train_op = None
return tf.estimator.EstimatorSpec(
mode=mode,
predictions=predictions,
......
......@@ -3,12 +3,14 @@
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
import tensorflow as tf
slim = tf.contrib.slim
import tensorflow.contrib.slim as slim
def append_logits(graph, n_classes, reuse=False, l2_regularizer=0.001, weights_std=0.1):
return slim.fully_connected(graph, n_classes, activation_fn=None,
weights_initializer=tf.truncated_normal_initializer(stddev=weights_std),
weights_regularizer=slim.l2_regularizer(l2_regularizer),
scope='Logits', reuse=reuse)
def append_logits(graph, n_classes, reuse=False, l2_regularizer=0.001,
weights_std=0.1):
return slim.fully_connected(
graph, n_classes, activation_fn=None,
weights_initializer=tf.truncated_normal_initializer(
stddev=weights_std),
weights_regularizer=slim.l2_regularizer(l2_regularizer),
scope='Logits', reuse=reuse)
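A hedged usage sketch for ``append_logits`` (not part of the diff); the feature dimensionality and class count are illustrative, and the input is assumed to be a rank-2 prelogits tensor:

import tensorflow as tf

prelogits = tf.placeholder(tf.float32, shape=(None, 128))
logits = append_logits(prelogits, n_classes=10)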
......@@ -89,7 +89,7 @@ An example configuration for a trained model and its evaluation could be::
# output_shapes)`` line is mandatory in the function below. You have to
# create it in your configuration file since you want it to be created in
# the same graph as your model.
def bio_predict_input_fn(generator,output_types, output_shapes):
def bio_predict_input_fn(generator, output_types, output_shapes):
def input_fn():
dataset = tf.data.Dataset.from_generator(generator, output_types,
output_shapes)
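The hunk stops before the rest of ``input_fn``; assuming it only needs to batch the generator output and hand ``data`` and ``key`` to the model as features, a complete version could look roughly like this (the batch size and dictionary keys are assumptions, not taken from the source):

def bio_predict_input_fn(generator, output_types, output_shapes):
    def input_fn():
        dataset = tf.data.Dataset.from_generator(generator, output_types,
                                                 output_shapes)
        dataset = dataset.batch(1)
        data, label, key = dataset.make_one_shot_iterator().get_next()
        return {'data': data, 'key': key}, label
    return input_fn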
......@@ -116,7 +116,7 @@ from bob.bio.base.utils import read_config_file, save
from bob.bio.base.tools.grid import indices
from bob.learn.tensorflow.utils.commandline import \
get_from_config_or_commandline
from bob.learn.tensorflow.dataset.bio import bio_generator
from bob.learn.tensorflow.dataset.bio import BioGenerator
from bob.core.log import setup, set_verbosity_level
logger = setup(__name__)
......@@ -140,9 +140,20 @@ def make_output_path(output_dir, key):
return os.path.join(output_dir, key + '.hdf5')
def non_existing_files(paths, force=False):
if force:
for i in range(len(paths)):
yield i
return
for i, path in enumerate(paths):
if not os.path.isfile(path):
yield i
def save_predictions(pool, output_dir, key, pred_buffer):
outpath = make_output_path(output_dir, key)
create_directories_safe(os.path.dirname(outpath))
logger.debug("Saving predictions for %s", key)
pool.apply_async(save, (np.mean(pred_buffer[key], axis=0), outpath))
......@@ -183,16 +194,33 @@ def main(argv=None):
output_dir = get_from_config_or_commandline(
config, 'output_dir', args, defaults, False)
assert len(biofiles), "biofiles are empty!"
logger.info("number_of_parallel_jobs: %d", number_of_parallel_jobs)
if number_of_parallel_jobs > 1:
start, end = indices(biofiles, number_of_parallel_jobs)
biofiles = biofiles[start:end]
generator, output_types, output_shapes = bio_generator(
# filter the existing files
paths = (make_output_path(output_dir, f.make_path("", ""))
for f in biofiles)
indexes = non_existing_files(paths, force)
biofiles = [biofiles[i] for i in indexes]
if len(biofiles) == 0:
logger.warning(
"The biofiles are empty after checking for existing files.")
return
generator = BioGenerator(
database, biofiles, load_data=load_data,
biofile_to_label=None, multiple_samples=multiple_samples, force=force)
multiple_samples=multiple_samples)
predict_input_fn = bio_predict_input_fn(
generator, generator.output_types, generator.output_shapes)
predict_input_fn = bio_predict_input_fn(generator,
output_types, output_shapes)
if checkpoint_path:
logger.info("Restoring the model from %s", checkpoint_path)
predictions = estimator.predict(
predict_input_fn,
......@@ -201,6 +229,8 @@ def main(argv=None):
checkpoint_path=checkpoint_path,
)
logger.info("Saving the predictions in %s", output_dir)
pool = Pool()
try:
pred_buffer = defaultdict(list)
......@@ -215,9 +245,8 @@ def main(argv=None):
else:
save_predictions(pool, output_dir, last_key, pred_buffer)
last_key = key
# else below is for the for loop
else:
save_predictions(pool, output_dir, key, pred_buffer)
# save the final returned key as well:
save_predictions(pool, output_dir, key, pred_buffer)
finally:
pool.close()
pool.join()
......
......@@ -3,18 +3,25 @@
"""Trains networks using Tensorflow estimators.
Usage:
%(prog)s [options] <config_files>...
%(prog)s --help
%(prog)s --version
%(prog)s [-v...] [options] <config_files>...
%(prog)s --help
%(prog)s --version
Arguments:
<config_files> The configuration files. The configuration files are loaded
in order and they need to have several objects inside
totally. See below for explanation.
<config_files> The configuration files. The
configuration files are loaded in order
and together they need to define several
objects. See below for explanation.
Options:
-h --help show this help message and exit
--version show version and exit
-h --help Show this help message and exit
--version Show version and exit
-v, --verbose Increases the output verbosity level
-s N, --steps N The number of steps to train.
-m N, --max-steps N The maximum number of steps to train.
This is a limit for global step which
continues in separate runs.
The configuration files should define the following objects in total:
......@@ -26,11 +33,6 @@ The configuration files should have the following objects totally:
## Optional objects:
hooks
steps
max_steps
For an example configuration, please see:
bob.learn.tensorflow/bob/learn/tensorflow/examples/mnist/mnist_config.py
"""
from __future__ import absolute_import
from __future__ import division
......@@ -38,6 +40,10 @@ from __future__ import print_function
# import pkg_resources so that bob imports work properly:
import pkg_resources
from bob.bio.base.utils import read_config_file
from bob.learn.tensorflow.utils.commandline import \
get_from_config_or_commandline
from bob.core.log import setup, set_verbosity_level
logger = setup(__name__)
def main(argv=None):
......@@ -46,17 +52,27 @@ def main(argv=None):
import sys
docs = __doc__ % {'prog': os.path.basename(sys.argv[0])}
version = pkg_resources.require('bob.learn.tensorflow')[0].version
defaults = docopt(docs, argv=[""])
args = docopt(docs, argv=argv, version=version)
config_files = args['<config_files>']
config = read_config_file(config_files)
# optional arguments
verbosity = get_from_config_or_commandline(
config, 'verbose', args, defaults)
max_steps = get_from_config_or_commandline(
config, 'max_steps', args, defaults)
steps = get_from_config_or_commandline(
config, 'steps', args, defaults)
hooks = getattr(config, 'hooks', None)
# Sets-up logging
set_verbosity_level(logger, verbosity)
# required arguments
estimator = config.estimator
train_input_fn = config.train_input_fn
hooks = getattr(config, 'hooks', None)
steps = getattr(config, 'steps', None)
max_steps = getattr(config, 'max_steps', None)
# Train
estimator.train(input_fn=train_input_fn, hooks=hooks, steps=steps,
max_steps=max_steps)
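For context, a minimal configuration file for this script would define the required objects roughly as below; the estimator's import path, the architecture/loss callables, and the input pipeline are placeholders, not taken from this merge request:

# config.py -- hypothetical minimal training configuration (sketch only)
import tensorflow as tf
from bob.learn.tensorflow.estimators import Logits  # assumed import path

estimator = Logits(
    architecture=my_beautiful_architecture,  # any function following the documented pattern
    optimizer=tf.train.GradientDescentOptimizer(1e-4),
    loss_op=my_beautiful_loss,
    n_classes=10,
    model_dir='/tmp/my_model')

def train_input_fn():
    # build and return a (features, labels) pair or a tf.data.Dataset here
    raise NotImplementedError

# optional objects picked up by the script
steps = 1000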
......
......@@ -26,7 +26,8 @@ rn.seed(12345)
# For further details, see:
# https://stackoverflow.com/questions/42022950/which-seeds-have-to-be-set-where-to-realize-100-reproducibility-of-training-res
session_config = tf.ConfigProto(intra_op_parallelism_threads=1,
inter_op_parallelism_threads=1)
inter_op_parallelism_threads=1,
log_device_placement=True)
# The below tf.set_random_seed() will make random number generation
# in the TensorFlow backend have a well-defined initial state.
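For completeness, the usual recipe referenced in the Stack Overflow link combines the single-threaded session configuration above with Python, NumPy, and TensorFlow seeds; a hedged sketch follows (the seed values are arbitrary):

import os
import random as rn
import numpy as np
import tensorflow as tf

os.environ['PYTHONHASHSEED'] = '0'
np.random.seed(42)
rn.seed(12345)

session_config = tf.ConfigProto(intra_op_parallelism_threads=1,
                                inter_op_parallelism_threads=1)
tf.set_random_seed(1234)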
......