From 52023e4ae7521824fb102cab2c77bb6db6528e24 Mon Sep 17 00:00:00 2001 From: Amir MOHAMMADI <amir.mohammadi@idiap.ch> Date: Tue, 10 Oct 2017 12:14:36 +0200 Subject: [PATCH] Move code from bob.dap to here --- .../tensorflow/script/db_to_tfrecords.py | 21 +-- bob/learn/tensorflow/script/eval_generic.py | 96 ++++++++++++ bob/learn/tensorflow/script/train_generic.py | 98 ++++++++++++ bob/learn/tensorflow/utils/eval.py | 110 ++++++++++++++ bob/learn/tensorflow/utils/hooks.py | 35 +++++ bob/learn/tensorflow/utils/reproducible.py | 37 +++++ bob/learn/tensorflow/utils/sequences.py | 143 ++++++++++++++++++ bob/learn/tensorflow/utils/tfrecords.py | 25 +++ 8 files changed, 555 insertions(+), 10 deletions(-) create mode 100644 bob/learn/tensorflow/script/eval_generic.py create mode 100644 bob/learn/tensorflow/script/train_generic.py create mode 100644 bob/learn/tensorflow/utils/eval.py create mode 100644 bob/learn/tensorflow/utils/hooks.py create mode 100644 bob/learn/tensorflow/utils/reproducible.py create mode 100644 bob/learn/tensorflow/utils/sequences.py create mode 100644 bob/learn/tensorflow/utils/tfrecords.py diff --git a/bob/learn/tensorflow/script/db_to_tfrecords.py b/bob/learn/tensorflow/script/db_to_tfrecords.py index 10cc5c19..9f9c2b46 100644 --- a/bob/learn/tensorflow/script/db_to_tfrecords.py +++ b/bob/learn/tensorflow/script/db_to_tfrecords.py @@ -91,20 +91,21 @@ from bob.core.log import setup, set_verbosity_level logger = setup(__name__) -def _bytes_feature(value): - return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) +def bytes_feature(value): + return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) -def _int64_feature(value): - return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) +def int64_feature(value): + return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) -def write_a_sample(writer, data, label): - feature = {'train/data': _bytes_feature(data.tostring()), - 'train/label': _int64_feature(label)} +def write_a_sample(writer, data, label, feature=None): + if feature is None: + feature = {'train/data': bytes_feature(data.tostring()), + 'train/label': int64_feature(label)} - example = tf.train.Example(features=tf.train.Features(feature=feature)) - writer.write(example.SerializeToString()) + example = tf.train.Example(features=tf.train.Features(feature=feature)) + writer.write(example.SerializeToString()) def main(argv=None): @@ -167,4 +168,4 @@ def main(argv=None): if __name__ == '__main__': - main() + main() diff --git a/bob/learn/tensorflow/script/eval_generic.py b/bob/learn/tensorflow/script/eval_generic.py new file mode 100644 index 00000000..23127563 --- /dev/null +++ b/bob/learn/tensorflow/script/eval_generic.py @@ -0,0 +1,96 @@ +#!/usr/bin/env python + +"""Trains the VGG-audio network on the AVspoof database. + +Usage: + %(prog)s [options] <checkpoint_dir> <eval_tfrecords>... + +Options: + -h --help Show this screen. + --eval-dir PATH [default: /idiap/user/amohammadi/avspoof/specgram/avspoof-simple-cnn-eval] + --input-shape N [default: (50, 1024, 1)] + --batch-size N [default: 50] + --run-once Evaluate the model once only. + --eval-interval-secs N Interval to evaluations. [default: 300] + --num-examples N Number of examples to run. [default: None] Provide + ``None`` to consider all examples. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +# for bob imports to work properly: +import pkg_resources +import os +import time +from functools import partial +import tensorflow as tf +from docopt import docopt +from bob.io.base import create_directories_safe +from bob.dap.voice.architectures.simple_cnn import architecture +from bob.dap.base.database.tfrecords import example_parser +from bob.dap.base.util.eval import get_global_step, eval_once + + +def main(argv=None): + arguments = docopt(__doc__, argv=argv) + print(arguments) + input_shape = eval(arguments['--input-shape']) + tfrecord_filenames = arguments['<eval_tfrecords>'] + eval_dir = arguments['--eval-dir'] + batch_size = eval(arguments['--batch-size']) + run_once = arguments['--run-once'] + eval_interval_secs = eval(arguments['--eval-interval-secs']) + checkpoint_dir = arguments['<checkpoint_dir>'] + num_examples = eval(arguments['--num-examples']) + + create_directories_safe(eval_dir) + with tf.Graph().as_default() as g: + + # Get images and labels + with tf.name_scope('input'): + dataset = tf.contrib.data.TFRecordDataset(tfrecord_filenames) + feature = {'train/data': tf.FixedLenFeature([], tf.string), + 'train/label': tf.FixedLenFeature([], tf.int64)} + my_example_parser = partial( + example_parser, feature=feature, data_shape=input_shape) + dataset = dataset.map( + my_example_parser, num_threads=1, output_buffer_size=batch_size) + dataset = dataset.batch(batch_size) + images, labels = dataset.make_one_shot_iterator().get_next() + + # Build a Graph that computes the logits predictions from the + # inference model. + logits = architecture(images, mode=tf.estimator.ModeKeys.EVAL) + + # Calculate predictions. + top_k_op = tf.nn.in_top_k(logits, labels, 1) + + saver = tf.train.Saver() + # Build the summary operation based on the TF collection of Summaries. + summary_op = tf.summary.merge_all() + summary_writer = tf.summary.FileWriter(eval_dir, g) + evaluated_file = os.path.join(eval_dir, 'evaluated') + + while True: + evaluated_steps = [] + if os.path.exists(evaluated_file): + with open(evaluated_file) as f: + evaluated_steps = f.read().split() + ckpt = tf.train.get_checkpoint_state(checkpoint_dir) + if ckpt and ckpt.model_checkpoint_path: + for path in ckpt.all_model_checkpoint_paths: + global_step = get_global_step(path) + if global_step not in evaluated_steps: + ret_val = eval_once(saver, summary_writer, top_k_op, summary_op, + path, global_step, num_examples, + batch_size) + if ret_val == 0: + with open(evaluated_file, 'a') as f: + f.write(global_step + '\n') + if run_once: + break + time.sleep(eval_interval_secs) + + +if __name__ == '__main__': + main() diff --git a/bob/learn/tensorflow/script/train_generic.py b/bob/learn/tensorflow/script/train_generic.py new file mode 100644 index 00000000..246b1c12 --- /dev/null +++ b/bob/learn/tensorflow/script/train_generic.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python + +"""Trains the VGG-audio network on the AVspoof database. + +Usage: + %(prog)s [options] <train_tfrecords>... + +Options: + -h --help Show this screen. + --save-dir PATH [default: /idiap/user/amohammadi/avspoof/specgram/avspoof-simple-cnn-train] + --input-shape N [default: (50, 1024, 1)] + --epochs N [default: None] + --batch-size N [default: 32] + --capacity-samples N The capacity of the queue [default: 10**4/2]. + --learning-rate N The learning rate [default: 0.00001]. + --log-frequency N How often to log results to the console. + [default: 100] +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +# for bob imports to work properly: +import pkg_resources +# for creating reproducible nets +from bob.dap.base.sandbox.reproducible import session_conf + +import tensorflow as tf +from docopt import docopt +from bob.io.base import create_directories_safe +from bob.dap.voice.architectures.simple_cnn import architecture +from bob.dap.base.database.tfrecords import read_and_decode +from bob.dap.base.util.hooks import LoggerHook + + +def main(argv=None): + arguments = docopt(__doc__, argv=argv) + print(arguments) + input_shape = eval(arguments['--input-shape']) + tfrecord_filenames = arguments['<train_tfrecords>'] + save_dir = arguments['--save-dir'] + epochs = eval(arguments['--epochs']) + batch_size = eval(arguments['--batch-size']) + capacity_samples = eval(arguments['--capacity-samples']) + learning_rate = eval(arguments['--learning-rate']) + log_frequency = eval(arguments['--log-frequency']) + + create_directories_safe(save_dir) + with tf.Graph().as_default(): + global_step = tf.contrib.framework.get_or_create_global_step() + + # Get images and labels + with tf.name_scope('input'): + filename_queue = tf.train.string_input_producer( + tfrecord_filenames, num_epochs=epochs, name="tfrecord_filenames") + + image, label = read_and_decode(filename_queue, input_shape) + images, labels = tf.train.shuffle_batch( + [image, label], batch_size=batch_size, + capacity=capacity_samples // batch_size, + min_after_dequeue=int(capacity_samples // batch_size // 2), + num_threads=1, name="shuffle_batch") + + # Build a Graph that computes the logits predictions from the + # inference model. + logits = architecture(images) + tf.add_to_collection('logits', logits) + + # Calculate loss. + predictor = tf.nn.sparse_softmax_cross_entropy_with_logits( + logits=logits, labels=labels) + loss = tf.reduce_mean(predictor) + tf.summary.scalar('loss', loss) + + # Build a Graph that trains the model with one batch of examples and + # updates the model parameters. + optimizer = tf.train.GradientDescentOptimizer(learning_rate) + train_op = optimizer.minimize(loss, global_step=global_step) + + saver = tf.train.Saver(max_to_keep=10**5) + scaffold = tf.train.Scaffold(saver=saver) + with tf.train.MonitoredTrainingSession( + checkpoint_dir=save_dir, + scaffold=scaffold, + hooks=[ + tf.train.CheckpointSaverHook( + save_dir, save_secs=60 * 29, scaffold=scaffold), + tf.train.NanTensorHook(loss), + LoggerHook(loss, batch_size, log_frequency)], + config=session_conf, + save_checkpoint_secs=None, + save_summaries_steps=100, + ) as mon_sess: + while not mon_sess.should_stop(): + mon_sess.run(train_op) + + +if __name__ == '__main__': + main() diff --git a/bob/learn/tensorflow/utils/eval.py b/bob/learn/tensorflow/utils/eval.py new file mode 100644 index 00000000..a431f8d2 --- /dev/null +++ b/bob/learn/tensorflow/utils/eval.py @@ -0,0 +1,110 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math +import numpy as np +import tensorflow as tf +from datetime import datetime + + +def get_global_step(path): + """Returns the global number associated with the model checkpoint path. The + checkpoint must have been saved with the + :any:`tf.train.MonitoredTrainingSession`. + + Parameters + ---------- + path : str + The path to model checkpoint, usually ckpt.model_checkpoint_path + + Returns + ------- + global_step : str + The global step number as a string. + """ + # Assuming model_checkpoint_path looks something like: + # /my-favorite-path/train/model.ckpt-0, + # extract global_step from it. + global_step = path.split('/')[-1].split('-')[-1] + return global_step + + +def eval_once(saver, summary_writer, prediction_op, summary_op, + model_checkpoint_path, global_step, num_examples, batch_size): + """Run Eval once. + + Parameters + ---------- + saver + Saver. + summary_writer + Summary writer. + prediction_op + Prediction operator. + summary_op + Summary operator. + model_checkpoint_path : str + Path to the model checkpoint. + global_step : str + The global step. + num_examples : int or None + The number of examples to try. + batch_size : int + The size of evaluation batch. + + This function requires the ``from __future__ import division`` import. + + Returns + ------- + int + 0 for success, anything else for fail. + + """ + with tf.Session() as sess: + sess.run(tf.local_variables_initializer()) + sess.run(tf.global_variables_initializer()) + + if model_checkpoint_path: + # Restores from checkpoint + saver.restore(sess, model_checkpoint_path) + else: + print('No checkpoint file found') + return -1 + + # Start the queue runners. + coord = tf.train.Coordinator() + try: + threads = [] + for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS): + threads.extend(qr.create_threads(sess, coord=coord, daemon=True, + start=True)) + + if num_examples is None: + num_iter = float("inf") + else: + num_iter = int(math.ceil(num_examples / batch_size)) + true_count = 0 # Counts the number of correct predictions. + total_sample_count = 0 + step = 0 + while step < num_iter and not coord.should_stop(): + predictions = sess.run([prediction_op]) + true_count += np.sum(predictions) + total_sample_count += np.asarray(predictions).size + step += 1 + + # Compute precision @ 1. + precision = true_count / total_sample_count + print('%s: precision @ 1 = %.3f' % (datetime.now(), precision)) + + summary = tf.Summary() + summary.ParseFromString(sess.run(summary_op)) + summary.value.add(tag='Precision @ 1', simple_value=precision) + summary_writer.add_summary(summary, global_step) + return 0 + except Exception as e: # pylint: disable=broad-except + coord.request_stop(e) + return -1 + finally: + coord.request_stop() + coord.join(threads, stop_grace_period_secs=10) diff --git a/bob/learn/tensorflow/utils/hooks.py b/bob/learn/tensorflow/utils/hooks.py new file mode 100644 index 00000000..fe15d519 --- /dev/null +++ b/bob/learn/tensorflow/utils/hooks.py @@ -0,0 +1,35 @@ +import tensorflow as tf +import time +from datetime import datetime + + +class LoggerHook(tf.train.SessionRunHook): + """Logs loss and runtime.""" + + def __init__(self, loss, batch_size, log_frequency): + self.loss = loss + self.batch_size = batch_size + self.log_frequency = log_frequency + + def begin(self): + self._step = -1 + self._start_time = time.time() + + def before_run(self, run_context): + self._step += 1 + return tf.train.SessionRunArgs(self.loss) # Asks for loss value. + + def after_run(self, run_context, run_values): + if self._step % self.log_frequency == 0: + current_time = time.time() + duration = current_time - self._start_time + self._start_time = current_time + + loss_value = run_values.results + examples_per_sec = self.log_frequency * self.batch_size / duration + sec_per_batch = float(duration / self.log_frequency) + + format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f ' + 'sec/batch)') + print(format_str % (datetime.now(), self._step, loss_value, + examples_per_sec, sec_per_batch)) diff --git a/bob/learn/tensorflow/utils/reproducible.py b/bob/learn/tensorflow/utils/reproducible.py new file mode 100644 index 00000000..87c4ce87 --- /dev/null +++ b/bob/learn/tensorflow/utils/reproducible.py @@ -0,0 +1,37 @@ +import os +import numpy as np +import tensorflow as tf +import random as rn +from tensorflow.contrib import keras + +# reproducible networks +# The below is necessary in Python 3.2.3 onwards to +# have reproducible behavior for certain hash-based operations. +# See these references for further details: +# https://docs.python.org/3.4/using/cmdline.html#envvar-PYTHONHASHSEED +# https://github.com/fchollet/keras/issues/2280#issuecomment-306959926 +os.environ['PYTHONHASHSEED'] = '0' + +# The below is necessary for starting Numpy generated random numbers +# in a well-defined initial state. +np.random.seed(42) + +# The below is necessary for starting core Python generated random numbers +# in a well-defined state. +rn.seed(12345) + +# Force TensorFlow to use single thread. +# Multiple threads are a potential source of +# non-reproducible results. +# For further details, see: +# https://stackoverflow.com/questions/42022950/which-seeds-have-to-be-set-where-to-realize-100-reproducibility-of-training-res +session_conf = tf.ConfigProto(intra_op_parallelism_threads=1, + inter_op_parallelism_threads=1) + +# The below tf.set_random_seed() will make random number generation +# in the TensorFlow backend have a well-defined initial state. +# For further details, see: +# https://www.tensorflow.org/api_docs/python/tf/set_random_seed +tf.set_random_seed(1234) +# sess = tf.Session(graph=tf.get_default_graph(), config=session_conf) +# keras.backend.set_session(sess) diff --git a/bob/learn/tensorflow/utils/sequences.py b/bob/learn/tensorflow/utils/sequences.py new file mode 100644 index 00000000..fa16cb82 --- /dev/null +++ b/bob/learn/tensorflow/utils/sequences.py @@ -0,0 +1,143 @@ +from __future__ import division +import numpy +from keras.utils import Sequence +# documentation imports +from bob.dap.base.database import PadDatabase, PadFile +from bob.bio.base.preprocessor import Preprocessor + + +class PadSequence(Sequence): + """A data shuffler for bob.dap.base database interfaces. + + Attributes + ---------- + batch_size : int + The number of samples to return in every batch. + files : list of :any:`PadFile` + List of file objects for a particular group and protocol. + labels : list of bool + List of labels for the files. ``True`` if bona-fide, ``False`` if + attack. + preprocessor : :any:`Preprocessor` + The preprocessor to be used to load and process the data. + """ + + def __init__(self, files, labels, batch_size, preprocessor, + original_directory, original_extension): + super(PadSequence, self).__init__() + self.files = files + self.labels = labels + self.batch_size = int(batch_size) + self.preprocessor = preprocessor + self.original_directory = original_directory + self.original_extension = original_extension + + def __len__(self): + """Number of batch in the Sequence. + + Returns + ------- + int + The number of batches in the Sequence. + """ + return int(numpy.ceil(len(self.files) / self.batch_size)) + + def __getitem__(self, idx): + files = self.files[idx * self.batch_size:(idx + 1) * self.batch_size] + labels = self.labels[idx * self.batch_size:(idx + 1) * self.batch_size] + return self.load_batch(files, labels) + + def load_batch(self, files, labels): + """Loads a batch of files and processes them. + + Parameters + ---------- + files : list of :any:`PadFile` + List of files to load. + labels : list of bool + List of labels corresponding to the files. + + Returns + ------- + tuple of :any:`numpy.array` + A tuple of (x, y): the data and their targets. + """ + data, targets = [], [] + for file_object, target in zip(files, labels): + loaded_data = self.preprocessor.read_original_data( + file_object, + self.original_directory, + self.original_extension) + preprocessed_data = self.preprocessor(loaded_data) + data.append(preprocessed_data) + targets.append(target) + return numpy.array(data), numpy.array(targets) + + def on_epoch_end(self): + pass + + +def shuffle_data(files, labels): + indexes = numpy.arange(len(files)) + numpy.random.shuffle(indexes) + return [files[i] for i in indexes], [labels[i] for i in indexes] + + +def get_pad_files_labels(database, groups): + """Returns the pad files and their labels. + + Parameters + ---------- + database : :any:`PadDatabase` + The database to be used. The database should have a proper + ``database.protocol`` attribute. + groups : str + The group to be used to return the data. One of ('world', 'dev', + 'eval'). 'world' means training data and 'dev' means validation data. + + Returns + ------- + tuple + A tuple of (files, labels) for that particular group and protocol. + """ + files = database.samples( + groups=groups, protocol=database.protocol) + labels = ((f.attack_type is None) for f in files) + labels = numpy.fromiter(labels, bool, len(files)) + return files, labels + + +def get_pad_sequences(database, preprocessor, batch_size, + groups=('world', 'dev', 'eval'), shuffle=False, + limit=None): + """Returns a list of :any:`Sequence` objects for the database. + + Parameters + ---------- + database : :any:`PadDatabase` + The database to be used. The database should have a proper + ``database.protocol`` attribute. + preprocessor : :any:`Preprocessor` + The preprocessor to be used to load and process the data. + batch_size : int + The number of samples to return in every batch. + groups : str + The group to be used to return the data. One of ('world', 'dev', + 'eval'). 'world' means training data and 'dev' means validation data. + + Returns + ------- + list of :any:`Sequence` + The requested sequences to be used. + """ + seqs = [] + for grp in groups: + files, labels = get_pad_files_labels(database, grp) + if shuffle: + files, labels = shuffle_data(files, labels) + if limit is not None: + files, labels = files[:limit], labels[:limit] + seqs.append(PadSequence(files, labels, batch_size, preprocessor, + database.original_directory, + database.original_extension)) + return seqs diff --git a/bob/learn/tensorflow/utils/tfrecords.py b/bob/learn/tensorflow/utils/tfrecords.py new file mode 100644 index 00000000..309575a2 --- /dev/null +++ b/bob/learn/tensorflow/utils/tfrecords.py @@ -0,0 +1,25 @@ +import tensorflow as tf + + +def example_parser(serialized_example, feature, data_shape): + """Parses a single tf.Example into image and label tensors.""" + # Decode the record read by the reader + features = tf.parse_single_example(serialized_example, features=feature) + # Convert the image data from string back to the numbers + image = tf.decode_raw(features['train/data'], tf.float32) + # Cast label data into int64 + label = tf.cast(features['train/label'], tf.int64) + # Reshape image data into the original shape + image = tf.reshape(image, data_shape) + return image, label + + +def read_and_decode(filename_queue, data_shape, feature=None): + + if feature is None: + feature = {'train/data': tf.FixedLenFeature([], tf.string), + 'train/label': tf.FixedLenFeature([], tf.int64)} + # Define a reader and read the next record + reader = tf.TFRecordReader() + _, serialized_example = reader.read(filename_queue) + return example_parser(serialized_example, feature, data_shape) -- GitLab