Commit fcb2fc15 authored by Tiago de Freitas Pereira

Merge branch 'predict' into '40-adopt-to-the-estimators-api'

Add a prediction script

See merge request !22
parents ffb39107 c01b17ba
2 merge requests: !22 "Add a prediction script", !21 Resolve "Adopt to the Estimators API"
@@ -17,16 +17,15 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
-from bob.learn.tensorflow.utils.reproducible import session_conf
+# create reproducible nets:
+from bob.learn.tensorflow.utils.reproducible import run_config
 import tensorflow as tf
+from bob.db.mnist import Database
 
 model_dir = '/tmp/mnist_model'
 train_tfrecords = ['/tmp/mnist_data/train.tfrecords']
 eval_tfrecords = ['/tmp/mnist_data/test.tfrecords']
 
-# by default create reproducible nets:
-run_config = tf.estimator.RunConfig()
-run_config = run_config.replace(session_config=session_conf)
 run_config = run_config.replace(keep_checkpoint_max=10**3)
 run_config = run_config.replace(save_checkpoints_secs=60)
@@ -39,22 +38,27 @@ def input_fn(mode, batch_size=1):
         features = tf.parse_single_example(
             serialized_example,
             features={
-                'image_raw': tf.FixedLenFeature([], tf.string),
+                'data': tf.FixedLenFeature([], tf.string),
                 'label': tf.FixedLenFeature([], tf.int64),
+                'key': tf.FixedLenFeature([], tf.string),
             })
-        image = tf.decode_raw(features['image_raw'], tf.uint8)
+        image = tf.decode_raw(features['data'], tf.uint8)
         image.set_shape([28 * 28])
 
         # Normalize the values of the image from the range
         # [0, 255] to [-0.5, 0.5]
         image = tf.cast(image, tf.float32) / 255 - 0.5
         label = tf.cast(features['label'], tf.int32)
-        return image, tf.one_hot(label, 10)
+        key = tf.cast(features['key'], tf.string)
+        return image, tf.one_hot(label, 10), key
 
     if mode == tf.estimator.ModeKeys.TRAIN:
         tfrecords_files = train_tfrecords
+    elif mode == tf.estimator.ModeKeys.EVAL:
+        tfrecords_files = eval_tfrecords
     else:
-        assert mode == tf.estimator.ModeKeys.EVAL, 'invalid mode'
+        assert mode == tf.estimator.ModeKeys.PREDICT, 'invalid mode'
         tfrecords_files = eval_tfrecords
 
     for tfrecords_file in tfrecords_files:
@@ -73,9 +77,9 @@ def input_fn(mode, batch_size=1):
     dataset = dataset.map(
         example_parser, num_threads=1, output_buffer_size=batch_size)
     dataset = dataset.batch(batch_size)
-    images, labels = dataset.make_one_shot_iterator().get_next()
+    images, labels, keys = dataset.make_one_shot_iterator().get_next()
 
-    return images, labels
+    return {'images': images, 'keys': keys}, labels
 
 
 def train_input_fn():
@@ -86,6 +90,10 @@ def eval_input_fn():
     return input_fn(tf.estimator.ModeKeys.EVAL)
 
 
+def predict_input_fn():
+    return input_fn(tf.estimator.ModeKeys.PREDICT)
+
+
 def mnist_model(inputs, mode):
     """Takes the MNIST inputs and mode and outputs a tensor of logits."""
     # Input Layer
@@ -164,13 +172,16 @@ def mnist_model(inputs, mode):
     return logits
 
 
-def model_fn(features, labels, mode):
+def model_fn(features, labels=None, mode=tf.estimator.ModeKeys.TRAIN):
     """Model function for MNIST."""
+    keys = features['keys']
+    features = features['images']
     logits = mnist_model(features, mode)
 
     predictions = {
         'classes': tf.argmax(input=logits, axis=1),
-        'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
+        'probabilities': tf.nn.softmax(logits, name='softmax_tensor'),
+        'keys': keys,
     }
 
     if mode == tf.estimator.ModeKeys.PREDICT:
@@ -202,3 +213,22 @@ def model_fn(features, labels, mode):
         loss=loss,
         train_op=train_op,
         eval_metric_ops=metrics)
+
+
+estimator = tf.estimator.Estimator(model_fn=model_fn, model_dir=model_dir,
+                                   params=None, config=run_config)
+
+output = train_tfrecords[0]
+db = Database()
+data, labels = db.data(groups='train')
+
+# output = eval_tfrecords[0]
+# db = Database()
+# data, labels = db.data(groups='test')
+
+samples = zip(data, labels, (str(i) for i in range(len(data))))
+
+
+def reader(sample):
+    return sample
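With the objects above in place, this config exposes everything the generic train, eval, and predict scripts need. A minimal driver sketch, assuming the config above is importable as a module named mnist_config (the module name and the step count are illustrative, not part of this commit):

    # assumes the example config above is importable as `mnist_config`
    from mnist_config import (estimator, train_input_fn, eval_input_fn,
                              predict_input_fn)

    estimator.train(input_fn=train_input_fn, steps=1000)   # TRAIN mode
    print(estimator.evaluate(input_fn=eval_input_fn))      # EVAL mode
    for pred in estimator.predict(input_fn=predict_input_fn):
        # model_fn returns 'classes', 'probabilities' and 'keys' in PREDICT mode
        print(pred['keys'], pred['classes'])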
# Required objects:
# you need a database object that inherits from
# bob.bio.base.database.BioDatabase (PAD dbs work too)
database = Database()
# the directory pointing to where the processed data is:
data_dir = '/idiap/temp/user/database_name/sub_directory/preprocessed'
# the directory to save the tfrecords in:
output_dir = '/idiap/temp/user/database_name/sub_directory'
# A function that converts a BioFile or a PadFile to a label:
# Example for PAD
def file_to_label(f):
    return f.attack_type is None
# Example for Bio (You may want to run this script for groups=['world'] only
# in biometric recognition experiments.)
CLIENT_IDS = (str(f.client_id) for f in db.all_files(groups=groups))
CLIENT_IDS = list(set(CLIENT_IDS))
CLIENT_IDS = dict(zip(CLIENT_IDS, range(len(CLIENT_IDS))))
def file_to_label(f):
    return CLIENT_IDS[str(f.client_id)]
# Optional objects:
# The groups that you want to create tfrecords for. It should be a list of
# 'world' ('train' in bob.pad.base), 'dev', and 'eval' values. [default:
# 'world']
groups = ['world']
# you need a reader function that reads the preprocessed files. [default:
# bob.bio.base.utils.load]
reader = Preprocessor().read_data
reader = Extractor().read_feature
# or
from bob.bio.base.utils import load as reader
# or a reader that casts images to uint8:
def reader(path):
    data = bob.bio.base.utils.load(path)
    return data.astype("uint8")
# extension of the preprocessed files. [default: '.hdf5']
data_extension = '.hdf5'
# Shuffle the files before writing them into a tfrecords. [default: False]
shuffle = True
# Whether each file contains one sample or more. [default: True] If
# this is False, the loaded samples from a file are iterated over and each
# of them is saved as an independent feature.
one_file_one_sample = True
@@ -3,80 +3,86 @@
 """Converts Bio and PAD datasets to TFRecords file formats.
 
 Usage:
-    %(prog)s <config_files>... [--allow-missing-files]
+    %(prog)s [-v...] [options] <config_files>...
     %(prog)s --help
     %(prog)s --version
 
 Arguments:
-    <config_files>  The configuration files. The configuration files are loaded
-                    in order and they need to have several objects inside
-                    totally. See below for explanation.
+    <config_files>              The configuration files. The configuration
+                                files are loaded in order and they need to have
+                                several objects inside totally. See below for
+                                explanation.
 
 Options:
-    -h --help      show this help message and exit
-    --version      show version and exit
+    -h --help                   Show this help message and exit
+    --version                   Show version and exit
+    -o PATH, --output PATH      Name of the output file.
+    --shuffle                   If provided, it will shuffle the samples.
+    --allow-failures            If provided, the samples which fail to load are
+                                ignored.
+    --multiple-samples          If provided, it means that the data provided by
+                                reader contains multiple samples with same
+                                label and path.
+    -v, --verbose               Increases the output verbosity level
 
 The best way to use this script is to send it to the io-big queue if you are at
 Idiap:
 
-    $ jman submit -i -q q1d -- bin/python %(prog)s <config_files>...
+    $ jman submit -i -q q1d -- %(prog)s <config_files>...
 
-The configuration files should have the following objects totally:
+The configuration files should have the following objects totally::
 
-    ## Required objects:
+    # Required objects:
+    samples : a list of all samples that you want to write in the tfrecords
+        file. Whatever is inside this list is passed to the reader.
+    reader : a function with the signature of
+        `data, label, key = reader(sample)` which takes a sample and
+        returns the loaded data, the label of the data, and a key which
+        is unique for every sample.
 
-    # you need a database object that inherits from
-    # bob.bio.base.database.BioDatabase (PAD dbs work too)
-    database = Database()
-
-    # the directory pointing to where the processed data is:
-    data_dir = '/idiap/temp/user/database_name/sub_directory/preprocessed'
-
-    # the directory to save the tfrecords in:
-    output_dir = '/idiap/temp/user/database_name/sub_directory'
-
-    # A function that converts a BioFile or a PadFile to a label:
-    # Example for PAD
-    def file_to_label(f):
-        return f.attack_type is None
-
-    # Example for Bio (You may want to run this script for groups=['world'] only
-    # in biometric recognition experiments.)
-    CLIENT_IDS = (str(f.client_id) for f in db.all_files(groups=groups))
-    CLIENT_IDS = list(set(CLIENT_IDS))
-    CLIENT_IDS = dict(zip(CLIENT_IDS, range(len(CLIENT_IDS))))
-
-    def file_to_label(f):
-        return CLIENT_IDS[str(f.client_id)]
-
-    ## Optional objects:
-
-    # The groups that you want to create tfrecords for. It should be a list of
-    # 'world' ('train' in bob.pad.base), 'dev', and 'eval' values. [default:
-    # 'world']
-    groups = ['world']
-
-    # you need a reader function that reads the preprocessed files. [default:
-    # bob.bio.base.utils.load]
-    reader = Preprocessor().read_data
-    reader = Extractor().read_feature
-    # or
-    from bob.bio.base.utils import load as reader
-    # or a reader that casts images to uint8:
-    def reader(path):
-        data = bob.bio.base.utils.load(path)
-        return data.astype("uint8")
-
-    # extension of the preprocessed files. [default: '.hdf5']
-    data_extension = '.hdf5'
-
-    # Shuffle the files before writing them into a tfrecords. [default: False]
-    shuffle = True
-
-    # Whether the each file contains one sample or more. [default: True] If
-    # this is False, the loaded samples from a file are iterated over and each
-    # of them is saved as an independent feature.
-    one_file_one_sample = True
+You can also provide the command line options in the configuration file too.
+It is needed to replace "-" with "_" when they are in the configuration file.
+The ones provided by command line overwrite the values of the config file.
+
+An example for mnist would be::
+
+    from bob.db.mnist import Database
+    db = Database()
+    data, labels = db.data(groups='train')
+
+    samples = zip(data, labels, (str(i) for i in range(len(data))))
+
+    def reader(sample):
+        return sample
+
+    allow_failures = True
+    output = '/tmp/mnist_train.tfrecords'
+    shuffle = True
+
+An example for bob.bio.base would be::
+
+    from bob.bio.base.test.dummy.database import database
+    from bob.bio.base.test.dummy.preprocessor import preprocessor
+
+    groups = 'dev'
+
+    samples = database.all_files(groups=groups)
+
+    CLIENT_IDS = (str(f.client_id) for f in database.all_files(groups=groups))
+    CLIENT_IDS = list(set(CLIENT_IDS))
+    CLIENT_IDS = dict(zip(CLIENT_IDS, range(len(CLIENT_IDS))))
+
+    def file_to_label(f):
+        return CLIENT_IDS[str(f.client_id)]
+
+    def reader(biofile):
+        data = preprocessor.read_original_data(
+            biofile, database.original_directory, database.original_extension)
+        label = file_to_label(biofile)
+        key = biofile.path
+        return (data, label, key)
 """
 from __future__ import absolute_import
@@ -85,10 +91,12 @@ from __future__ import print_function
 import random
 
 # import pkg_resources so that bob imports work properly:
 import pkg_resources
+import six
 import tensorflow as tf
 from bob.io.base import create_directories_safe
-from bob.bio.base.utils import load, read_config_file
+from bob.bio.base.utils import read_config_file
+from bob.learn.tensorflow.utils.commandline import \
+    get_from_config_or_commandline
 from bob.core.log import setup, set_verbosity_level
 
 logger = setup(__name__)
@@ -101,10 +109,11 @@ def int64_feature(value):
     return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
 
 
-def write_a_sample(writer, data, label, feature=None):
+def write_a_sample(writer, data, label, key, feature=None):
     if feature is None:
-        feature = {'train/data': bytes_feature(data.tostring()),
-                   'train/label': int64_feature(label)}
+        feature = {'data': bytes_feature(data.tostring()),
+                   'label': int64_feature(label),
+                   'key': bytes_feature(key)}
 
     example = tf.train.Example(features=tf.train.Features(feature=feature))
     writer.write(example.SerializeToString())
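The feature keys written here ('data', 'label', 'key') are the same ones the new example_parser in the MNIST config expects, so a record written by this helper can be parsed back directly. A small round-trip sketch, with a made-up output path and dummy sample:

    import numpy as np
    import tensorflow as tf

    image = np.zeros((28, 28), dtype='uint8')  # a fake 28x28 uint8 image
    with tf.python_io.TFRecordWriter('/tmp/example.tfrecords') as writer:
        write_a_sample(writer, image, 3, b'0')  # data, label, key

    # parse it back with the same feature keys
    serialized = next(tf.python_io.tf_record_iterator('/tmp/example.tfrecords'))
    features = tf.parse_single_example(serialized, features={
        'data': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([], tf.int64),
        'key': tf.FixedLenFeature([], tf.string),
    })
    image_back = tf.decode_raw(features['data'], tf.uint8)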
@@ -116,55 +125,62 @@ def main(argv=None):
     import sys
     docs = __doc__ % {'prog': os.path.basename(sys.argv[0])}
     version = pkg_resources.require('bob.learn.tensorflow')[0].version
+    defaults = docopt(docs, argv=[""])
     args = docopt(docs, argv=argv, version=version)
     config_files = args['<config_files>']
     config = read_config_file(config_files)
-    allow_missing_files = args['--allow-missing-files']
+
+    # optional arguments
+    verbosity = get_from_config_or_commandline(
+        config, 'verbose', args, defaults)
+    allow_failures = get_from_config_or_commandline(
+        config, 'allow_failures', args, defaults)
+    multiple_samples = get_from_config_or_commandline(
+        config, 'multiple_samples', args, defaults)
+    shuffle = get_from_config_or_commandline(
+        config, 'shuffle', args, defaults)
 
     # Sets-up logging
-    verbosity = getattr(config, 'verbose', 0)
     set_verbosity_level(logger, verbosity)
 
-    database = config.database
-    data_dir, output_dir = config.data_dir, config.output_dir
-    file_to_label = config.file_to_label
-
-    reader = getattr(config, 'reader', load)
-    groups = getattr(config, 'groups', ['world'])
-    data_extension = getattr(config, 'data_extension', '.hdf5')
-    shuffle = getattr(config, 'shuffle', False)
-    one_file_one_sample = getattr(config, 'one_file_one_sample', True)
-
-    create_directories_safe(output_dir)
-    if not isinstance(groups, (list, tuple)):
-        groups = [groups]
-
-    for group in groups:
-        output_file = os.path.join(output_dir, '{}.tfrecords'.format(group))
-        files = database.all_files(groups=group)
-        if shuffle:
-            random.shuffle(files)
-        n_files = len(files)
-        with tf.python_io.TFRecordWriter(output_file) as writer:
-            for i, f in enumerate(files):
-                logger.info('Processing file %d out of %d', i + 1, n_files)
-                path = f.make_path(data_dir, data_extension)
-                data = reader(path)
-                if data is None:
-                    if allow_missing_files:
-                        logger.debug("... Processing original data file '{0}' was not successful".format(path))
-                        continue
-                    else:
-                        raise RuntimeError("Preprocessing of file '{0}' was not successful".format(path))
-                label = file_to_label(f)
-                if one_file_one_sample:
-                    write_a_sample(writer, data, label)
-                else:
-                    for sample in data:
-                        write_a_sample(writer, sample, label)
+    # required arguments
+    samples = config.samples
+    reader = config.reader
+    output = get_from_config_or_commandline(
+        config, 'output', args, defaults, False)
+
+    if not output.endswith(".tfrecords"):
+        output += ".tfrecords"
+
+    create_directories_safe(os.path.dirname(output))
+
+    n_samples = len(samples)
+    sample_counter = 0
+    with tf.python_io.TFRecordWriter(output) as writer:
+        if shuffle:
+            random.shuffle(samples)
+        for i, sample in enumerate(samples):
+            logger.info('Processing file %d out of %d', i + 1, n_samples)
+
+            data, label, key = reader(sample)
+
+            if data is None:
+                if allow_failures:
+                    logger.debug("... Skipping `{0}`.".format(sample))
+                    continue
+                else:
+                    raise RuntimeError(
+                        "Reading failed for `{0}`".format(sample))
+
+            if multiple_samples:
+                for sample in data:
+                    write_a_sample(writer, sample, label, key)
+                    sample_counter += 1
+            else:
+                write_a_sample(writer, data, label, key)
+                sample_counter += 1
+
+    print("Wrote {} samples into the tfrecords file.".format(sample_counter))
 
 
 if __name__ == '__main__':
@@ -63,7 +63,7 @@ def main(argv=None):
     model_fn = config.model_fn
     eval_input_fn = config.eval_input_fn
 
-    eval_interval_secs = getattr(config, 'eval_interval_secs', 300)
+    eval_interval_secs = getattr(config, 'eval_interval_secs', 60)
     run_once = getattr(config, 'run_once', False)
     run_config = getattr(config, 'run_config', None)
     model_params = getattr(config, 'model_params', None)
@@ -75,7 +75,7 @@ def main(argv=None):
     nn = tf.estimator.Estimator(model_fn=model_fn, model_dir=model_dir,
                                 params=model_params, config=run_config)
 
     if name:
-        real_name = name + '_eval'
+        real_name = 'eval_' + name
     else:
         real_name = 'eval'
     evaluated_file = os.path.join(nn.model_dir, real_name, 'evaluated')
@@ -91,7 +91,12 @@ def main(argv=None):
             continue
 
         for checkpoint_path in ckpt.all_model_checkpoint_paths:
-            global_step = str(get_global_step(checkpoint_path))
+            try:
+                global_step = str(get_global_step(checkpoint_path))
+            except Exception:
+                print('Failed to find global_step for checkpoint_path {}, '
+                      'skipping ...'.format(checkpoint_path))
+                continue
             if global_step in evaluated_steps:
                 continue
#!/usr/bin/env python
"""Returns predictions of networks trained with
tf.train.MonitoredTrainingSession
Usage:
%(prog)s [-v...] [-k KEY]... [options] <config_files>...
%(prog)s --help
%(prog)s --version
Arguments:
<config_files> The configuration files. The configuration
files are loaded in order and they need to
have several objects inside totally. See
below for explanation.
Options:
-h --help Show this help message and exit
--version Show version and exit
-o PATH, --output-dir PATH Name of the output file.
-k KEY, --predict-keys KEY List of `str`, name of the keys to predict.
It is used if the
`EstimatorSpec.predictions` is a `dict`. If
`predict_keys` is used then rest of the
predictions will be filtered from the
dictionary. If `None`, returns all.
--checkpoint-path=<path> Path of a specific checkpoint to predict.
If `None`, the latest checkpoint in
`model_dir` is used.
-v, --verbose Increases the output verbosity level
The configuration files should have the following objects totally:
# Required objects:
estimator
predict_input_fn
# Optional objects:
hooks
For an example configuration, please see:
bob.learn.tensorflow/bob/learn/tensorflow/examples/mnist/mnist_config.py
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# import pkg_resources so that bob imports work properly:
import pkg_resources
import os
from multiprocessing import Pool
from collections import defaultdict
import numpy as np
from bob.io.base import create_directories_safe
from bob.bio.base.utils import read_config_file, save
from bob.learn.tensorflow.utils.commandline import \
get_from_config_or_commandline
from bob.core.log import setup, set_verbosity_level
logger = setup(__name__)
def save_predictions(pool, output_dir, key, pred_buffer):
    outpath = os.path.join(output_dir, key + '.hdf5')
    create_directories_safe(os.path.dirname(outpath))
    pool.apply_async(save, (np.mean(pred_buffer[key], axis=0), outpath))


def main(argv=None):
    from docopt import docopt
    import sys
    docs = __doc__ % {'prog': os.path.basename(sys.argv[0])}
    version = pkg_resources.require('bob.learn.tensorflow')[0].version
    defaults = docopt(docs, argv=[""])
    args = docopt(docs, argv=argv, version=version)
    config_files = args['<config_files>']
    config = read_config_file(config_files)

    # optional arguments
    verbosity = get_from_config_or_commandline(
        config, 'verbose', args, defaults)
    predict_keys = get_from_config_or_commandline(
        config, 'predict_keys', args, defaults)
    checkpoint_path = get_from_config_or_commandline(
        config, 'checkpoint_path', args, defaults)
    hooks = getattr(config, 'hooks', None)

    # Sets-up logging
    set_verbosity_level(logger, verbosity)

    # required arguments
    estimator = config.estimator
    predict_input_fn = config.predict_input_fn
    output_dir = get_from_config_or_commandline(
        config, 'output_dir', args, defaults, False)

    predictions = estimator.predict(
        predict_input_fn,
        predict_keys=predict_keys,
        hooks=hooks,
        checkpoint_path=checkpoint_path,
    )

    pool = Pool()
    try:
        pred_buffer = defaultdict(list)
        for i, pred in enumerate(predictions):
            key = pred['keys']
            prob = pred.get('probabilities', pred.get('embeddings'))
            pred_buffer[key].append(prob)
            if i == 0:
                last_key = key
            if last_key == key:
                continue
            else:
                save_predictions(pool, output_dir, last_key, pred_buffer)
                last_key = key
        # else below is for the for loop
        else:
            save_predictions(pool, output_dir, key, pred_buffer)
    finally:
        pool.close()
        pool.join()


if __name__ == '__main__':
    main()
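A configuration file for this script only needs to expose the two required objects (plus output_dir, which can also be given with --output-dir). A minimal sketch, assuming the MNIST example config from this commit is importable as a module; the file name, import path, and output path below are illustrative:

    # predict_config.py (hypothetical file name)
    from bob.learn.tensorflow.examples.mnist.mnist_config import (
        estimator, predict_input_fn)

    # directory where the per-key .hdf5 files (the mean of the buffered
    # 'probabilities' or 'embeddings') will be written
    output_dir = '/tmp/mnist_predictions'

It would then be run through the bob_tf_predict_generic entry point added in setup.py below.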
@@ -66,9 +66,7 @@ def main(argv=None):
     if run_config is None:
         # by default create reproducible nets:
-        from bob.learn.tensorflow.utils.reproducible import session_conf
-        run_config = tf.estimator.RunConfig()
-        run_config.replace(session_config=session_conf)
+        from bob.learn.tensorflow.utils.reproducible import run_config
 
     # Instantiate Estimator
     nn = tf.estimator.Estimator(model_fn=model_fn, model_dir=model_dir,
-import os
 from bob.bio.base.test.dummy.database import database
-
-preprocessor = extractor = algorithm = 'dummy'
-groups = ['dev']
-
-temp_directory = result_directory = 'TEST_DIR'
-sub_directory = 'sub_directory'
-data_dir = os.path.join('TEST_DIR', sub_directory, 'preprocessed')
-
-# the directory to save the tfrecords in:
-output_dir = os.path.join('TEST_DIR', sub_directory)
+from bob.bio.base.test.dummy.preprocessor import preprocessor
+
+groups = 'dev'
+
+files = database.all_files(groups=groups)
 
 CLIENT_IDS = (str(f.client_id) for f in database.all_files(groups=groups))
 CLIENT_IDS = list(set(CLIENT_IDS))
@@ -18,3 +12,11 @@ CLIENT_IDS = dict(zip(CLIENT_IDS, range(len(CLIENT_IDS))))
 
 def file_to_label(f):
     return CLIENT_IDS[str(f.client_id)]
+
+
+def reader(biofile):
+    data = preprocessor.read_original_data(
+        biofile, database.original_directory, database.original_extension)
+    label = file_to_label(biofile)
+    key = biofile.path
+    return (data, label, key)
@@ -4,3 +4,4 @@ from .session import Session
 from . import hooks
 from . import eval
 from . import tfrecords
+from . import commandline
def get_from_config_or_commandline(config, keyword, args, defaults,
                                   default_is_valid=True):
    """Takes the value from the command line, the config file, or the default
    value, in that order of precedence (the command line overrides the config
    file, which overrides the default).

    Only certain kinds of command line options can be used with this function:

      - boolean flags
      - repeating flags (like --verbose)
      - options where the user will never provide the default value through
        the command line, for example when [default: None]
Parameters
----------
config : object
The loaded config files.
keyword : str
The keyword to load from the config file or through command line.
args : dict
The arguments parsed with docopt.
defaults : dict
The arguments parsed with docopt when ``argv=[]``.
default_is_valid : bool, optional
If False, will raise an exception if the final parsed value is the
default value.
Returns
-------
object
The bool or integer value of the corresponding keyword.
Example
-------
>>> from bob.bio.base.utils import read_config_file
>>> defaults = docopt(docs, argv=[""])
>>> args = docopt(docs, argv=argv)
>>> config_files = args['<config_files>']
>>> config = read_config_file(config_files)
>>> verbosity = get_from_config_or_commandline(config, 'verbose', args,
... defaults)
"""
    arg_keyword = '--' + keyword.replace('_', '-')

    # load from config first
    value = getattr(config, keyword, defaults[arg_keyword])
    # override it if provided by command line arguments
    if args[arg_keyword] != defaults[arg_keyword]:
        value = args[arg_keyword]

    if not default_is_valid and value == defaults[arg_keyword]:
        raise ValueError(
            "The value provided for {} is not valid.".format(keyword))

    return value
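A tiny self-contained illustration of the precedence implemented above (all values are hypothetical): the command line beats the config file, which beats the docopt default.

    defaults = {'--verbose': 0}   # docopt(docs, argv=[""])
    args = {'--verbose': 2}       # docopt(docs, argv=argv), e.g. with -vv given

    class config:                 # stand-in for what read_config_file returns
        verbose = 1

    # the command line value (2) wins over the config file (1) and the default (0)
    assert get_from_config_or_commandline(config, 'verbose', args, defaults) == 2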
@@ -25,13 +25,18 @@ rn.seed(12345)
 # non-reproducible results.
 # For further details, see:
 # https://stackoverflow.com/questions/42022950/which-seeds-have-to-be-set-where-to-realize-100-reproducibility-of-training-res
-session_conf = tf.ConfigProto(intra_op_parallelism_threads=1,
-                              inter_op_parallelism_threads=1)
+session_config = tf.ConfigProto(intra_op_parallelism_threads=1,
+                                inter_op_parallelism_threads=1)
 
 # The below tf.set_random_seed() will make random number generation
 # in the TensorFlow backend have a well-defined initial state.
 # For further details, see:
 # https://www.tensorflow.org/api_docs/python/tf/set_random_seed
-tf.set_random_seed(1234)
-# sess = tf.Session(graph=tf.get_default_graph(), config=session_conf)
+tf_random_seed = 1234
+tf.set_random_seed(tf_random_seed)
+# sess = tf.Session(graph=tf.get_default_graph(), config=session_config)
 # keras.backend.set_session(sess)
+
+run_config = tf.estimator.RunConfig()
+run_config = run_config.replace(session_config=session_config)
+run_config = run_config.replace(tf_random_seed=tf_random_seed)
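With run_config now built here, configuration files can import the ready-made reproducible RunConfig and tweak it, as the MNIST config above does. A short sketch; the checkpoint settings mirror the MNIST example and model_fn is assumed to be defined elsewhere in the same config:

    import tensorflow as tf
    from bob.learn.tensorflow.utils.reproducible import run_config

    # replace() returns a new RunConfig, so reassign to keep the changes
    run_config = run_config.replace(keep_checkpoint_max=10**3)
    run_config = run_config.replace(save_checkpoints_secs=60)

    estimator = tf.estimator.Estimator(model_fn=model_fn,  # model_fn from the config
                                       model_dir='/tmp/mnist_model',
                                       config=run_config)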
@@ -47,13 +47,13 @@ setup(
         # scripts should be declared using this entry:
         'console_scripts': [
-            'compute_statistics.py = bob.learn.tensorflow.script.compute_statistics:main',
-            'train.py = bob.learn.tensorflow.script.train:main',
-            'bob_db_to_tfrecords = bob.learn.tensorflow.script.db_to_tfrecords:main',
-            'load_and_debug.py = bob.learn.tensorflow.script.load_and_debug:main',
-            'lfw_db_to_tfrecords.py = bob.learn.tensorflow.script.lfw_db_to_tfrecords:main',
+            'bob_tf_compute_statistics.py = bob.learn.tensorflow.script.compute_statistics:main',
+            'bob_tf_db_to_tfrecords = bob.learn.tensorflow.script.db_to_tfrecords:main',
+            'bob_tf_load_and_debug.py = bob.learn.tensorflow.script.load_and_debug:main',
+            'bob_tf_lfw_db_to_tfrecords.py = bob.learn.tensorflow.script.lfw_db_to_tfrecords:main',
             'bob_tf_train_generic = bob.learn.tensorflow.script.train_generic:main',
             'bob_tf_eval_generic = bob.learn.tensorflow.script.eval_generic:main',
+            'bob_tf_predict_generic = bob.learn.tensorflow.script.predict_generic:main',
         ],
     },