From 474802415ca36af3e4200ee70a887f10f0b1b8af Mon Sep 17 00:00:00 2001
From: Tiago Pereira <tiago.pereira@partner.samsung.com>
Date: Sun, 16 Apr 2017 15:50:25 -0700
Subject: [PATCH] Reformulating the Trainer

---
 bob/learn/tensorflow/datashuffler/Base.py     |  98 ++++--
 bob/learn/tensorflow/network/__init__.py      |   1 +
 bob/learn/tensorflow/test/test_cnn.py         |  42 +--
 .../test/test_cnn_pretrained_model.py         |  96 +++---
 bob/learn/tensorflow/test/test_cnn_scratch.py |   2 +-
 bob/learn/tensorflow/trainers/Trainer.py      | 301 ++++++++----------
 6 files changed, 271 insertions(+), 269 deletions(-)

diff --git a/bob/learn/tensorflow/datashuffler/Base.py b/bob/learn/tensorflow/datashuffler/Base.py
index 47d18f00..a4aab2cb 100644
--- a/bob/learn/tensorflow/datashuffler/Base.py
+++ b/bob/learn/tensorflow/datashuffler/Base.py
@@ -39,16 +39,26 @@ class Base(object):
     normalizer: The algorithm used for feature scaling. Look :py:class:`bob.learn.tensorflow.datashuffler.ScaleFactor`, :py:class:`bob.learn.tensorflow.datashuffler.Linear` and :py:class:`bob.learn.tensorflow.datashuffler.MeanOffset`
+
+    prefetch:
+      If True, the data is fed to the graph through a tf.FIFOQueue instead of a plain placeholder
+
+    prefetch_capacity:
+      Capacity of the prefetch queue, in batches
     """

     def __init__(self, data, labels,
-                 input_shape,
+                 input_shape=[None, 28, 28, 1],
                  input_dtype="float64",
-                 batch_size=1,
+                 batch_size=32,
                  seed=10,
                  data_augmentation=None,
-                 normalizer=Linear()):
+                 normalizer=Linear(),
+                 prefetch=False,
+                 prefetch_capacity=10):
+
+        # Setting the seed for the pseudo random number generator
         self.seed = seed
         numpy.random.seed(seed)
@@ -58,10 +68,9 @@ class Base(object):
         # TODO: Check if the bacth size is higher than the input data
         self.batch_size = batch_size

+        # Preparing the inputs
         self.data = data
-        self.shape = tuple([batch_size] + input_shape)
         self.input_shape = tuple(input_shape)
-
         self.labels = labels
         self.possible_labels = list(set(self.labels))
@@ -72,43 +81,72 @@ class Base(object):
         self.indexes = numpy.array(range(self.n_samples))
         numpy.random.shuffle(self.indexes)

-        self.data_placeholder = None
-        self.label_placeholder = None
-
+        # Use data augmentation?
         self.data_augmentation = data_augmentation

-        self.deployment_shape = [None] + list(input_shape)

-    def set_placeholders(self, data, label):
-        self.data_placeholder = data
-        self.label_placeholder = label
+        # Preparing placeholders
+        self.data_ph = None
+        self.label_ph = None

+        # Prefetch variables
+        self.prefetch = prefetch
+        self.prefetch_capacity = prefetch_capacity
+        self.data_ph_from_queue = None
+        self.label_ph_from_queue = None

-    def get_batch(self):
+    def create_placeholders(self):
         """
-        Shuffle dataset and get a random batch.
+        Create the placeholder instances
         """
-        raise NotImplementedError("Method not implemented in this level. You should use one of the derived classes.")
+        with tf.name_scope("Input"):
+
+            self.data_ph = tf.placeholder(tf.float32, shape=self.input_shape, name="data")
+            self.label_ph = tf.placeholder(tf.int64, shape=[None], name="label")
+
+            # If prefetch, setup the queue to feed data
+            if self.prefetch:
+                queue = tf.FIFOQueue(capacity=self.prefetch_capacity,
+                                     dtypes=[tf.float32, tf.int64],
+                                     shapes=[self.input_shape[1:], []])

-    def get_placeholders(self, name=""):
+                # Fetching the place holders from the queue
+                self.enqueue_op = queue.enqueue_many([self.data_ph, self.label_ph])
+                self.data_ph_from_queue, self.label_ph_from_queue = queue.dequeue_many(self.batch_size)
+
+            else:
+                self.data_ph_from_queue = self.data_ph
+                self.label_ph_from_queue = self.label_ph
+
+    def __call__(self, element, from_queue=False):
         """
-        Returns a place holder with the size of your batch
+        Return the requested placeholder ("data" or "label")
+
         """
-        if self.data_placeholder is None:
-            self.data_placeholder = tf.placeholder(tf.float32, shape=self.shape, name=name)
+        if element not in ["data", "label"]:
+            raise ValueError("Value '{0}' invalid. Options available are 'data' and 'label'".format(element))
+
+        # If None, create the placeholders from scratch
+        if self.data_ph is None:
+            self.create_placeholders()

-        if self.label_placeholder is None:
-            self.label_placeholder = tf.placeholder(tf.int64, shape=self.shape[0])
+        if element == "data":
+            if from_queue:
+                return self.data_ph_from_queue
+            else:
+                return self.data_ph

-        return [self.data_placeholder, self.label_placeholder]
+        else:
+            if from_queue:
+                return self.label_ph_from_queue
+            else:
+                return self.label_ph

-    def get_placeholders_forprefetch(self, name=""):
+    def get_batch(self):
         """
-        Returns a place holder with the size of your batch
+        Shuffle dataset and get a random batch.
         """
-        if self.data_placeholder is None:
-            self.data_placeholder = tf.placeholder(tf.float32, shape=tuple([None] + list(self.shape[1:])), name=name)
-            self.label_placeholder = tf.placeholder(tf.int64, shape=[None, ])
-        return [self.data_placeholder, self.label_placeholder]
+        raise NotImplementedError("Method not implemented in this level. You should use one of the derived classes.")

     def bob2skimage(self, bob_image):
         """
@@ -167,10 +205,6 @@ class Base(object):
         else:
             return data

-    def reshape_for_deploy(self, data):
-        shape = tuple([1] + list(data.shape))
-        return numpy.reshape(data, shape)
-
     def normalize_sample(self, x):
         """
         Normalize the sample.
diff --git a/bob/learn/tensorflow/network/__init__.py b/bob/learn/tensorflow/network/__init__.py index e83fb345..0a88bc3c 100644 --- a/bob/learn/tensorflow/network/__init__.py +++ b/bob/learn/tensorflow/network/__init__.py @@ -10,6 +10,7 @@ from .VGG16 import VGG16 from .VGG16_mod import VGG16_mod from .SimpleAudio import SimpleAudio from .Embedding import Embedding +#from .Input import Input # gets sphinx autodoc done right - don't remove it def __appropriate__(*args): diff --git a/bob/learn/tensorflow/test/test_cnn.py b/bob/learn/tensorflow/test/test_cnn.py index 59e62238..4122a547 100644 --- a/bob/learn/tensorflow/test/test_cnn.py +++ b/bob/learn/tensorflow/test/test_cnn.py @@ -5,7 +5,7 @@ import numpy from bob.learn.tensorflow.datashuffler import Memory, SiameseMemory, TripletMemory, ImageAugmentation, ScaleFactor -from bob.learn.tensorflow.network import Chopra, SequenceNetwork +from bob.learn.tensorflow.network import Chopra from bob.learn.tensorflow.loss import BaseLoss, ContrastiveLoss, TripletLoss from bob.learn.tensorflow.trainers import Trainer, SiameseTrainer, TripletTrainer, constant from .test_cnn_scratch import validate_network @@ -23,7 +23,7 @@ Some unit tests for the datashuffler """ batch_size = 32 -validation_batch_size = 400 +validation_batch_size = 32 iterations = 300 seed = 10 @@ -77,6 +77,7 @@ def dummy_experiment(data_s, architecture): def test_cnn_trainer(): + # Loading data train_data, train_labels, validation_data, validation_labels = load_mnist() train_data = numpy.reshape(train_data, (train_data.shape[0], 28, 28, 1)) validation_data = numpy.reshape(validation_data, (validation_data.shape[0], 28, 28, 1)) @@ -84,40 +85,45 @@ def test_cnn_trainer(): # Creating datashufflers data_augmentation = ImageAugmentation() train_data_shuffler = Memory(train_data, train_labels, - input_shape=[28, 28, 1], + input_shape=[None, 28, 28, 1], batch_size=batch_size, data_augmentation=data_augmentation, normalizer=ScaleFactor()) + validation_data_shuffler = Memory(validation_data, validation_labels, + input_shape=[None, 28, 28, 1], + batch_size=batch_size, + data_augmentation=data_augmentation, + normalizer=ScaleFactor()) + directory = "./temp/cnn" # Loss for the softmax loss = BaseLoss(tf.nn.sparse_softmax_cross_entropy_with_logits, tf.reduce_mean) - inputs = {} - inputs['data'] = tf.placeholder(tf.float32, shape=[None, 28, 28, 1], name="train_data") - inputs['label'] = tf.placeholder(tf.int64, shape=[None], name="train_label") - # Preparing the architecture architecture = Chopra(seed=seed, fc1_output=10) - graph = architecture(inputs['data']) - embedding = Embedding(inputs['data'], graph) + input_pl = train_data_shuffler("data", from_queue=True) + graph = architecture(input_pl) + embedding = Embedding(train_data_shuffler("data", from_queue=False), graph) # One graph trainer - trainer = Trainer(inputs=inputs, - graph=graph, - loss=loss, + trainer = Trainer(train_data_shuffler, iterations=iterations, analizer=None, - prefetch=False, - learning_rate=constant(0.01, name="regular_lr"), - optimizer=tf.train.GradientDescentOptimizer(0.01), temp_dir=directory ) - trainer.train(train_data_shuffler) + trainer.create_network_from_scratch(graph=graph, + loss=loss, + learning_rate=constant(0.01, name="regular_lr"), + optimizer=tf.train.GradientDescentOptimizer(0.01), + ) + trainer.train() + #trainer.train(validation_data_shuffler) + + # Using embedding to compute the accuracy accuracy = validate_network(embedding, validation_data, validation_labels) - # At least 80% of accuracy assert accuracy 
> 80.
     shutil.rmtree(directory)

@@ -165,8 +171,6 @@ def test_siamesecnn_trainer():
                             optimizer=tf.train.AdamOptimizer(name="adam_siamese"),
                             temp_dir=directory
                             )
-
-    import ipdb; ipdb.set_trace();
     trainer.train(train_data_shuffler)

     eer = dummy_experiment(validation_data_shuffler, architecture)
diff --git a/bob/learn/tensorflow/test/test_cnn_pretrained_model.py b/bob/learn/tensorflow/test/test_cnn_pretrained_model.py
index 62ee4554..645e2ae8 100644
--- a/bob/learn/tensorflow/test/test_cnn_pretrained_model.py
+++ b/bob/learn/tensorflow/test/test_cnn_pretrained_model.py
@@ -10,9 +10,11 @@ from bob.learn.tensorflow.datashuffler import Memory, ImageAugmentation, Triplet
 from bob.learn.tensorflow.loss import BaseLoss, TripletLoss, ContrastiveLoss
 from bob.learn.tensorflow.trainers import Trainer, constant, TripletTrainer, SiameseTrainer
 from bob.learn.tensorflow.utils import load_mnist
-from bob.learn.tensorflow.network import SequenceNetwork
 from bob.learn.tensorflow.layers import Conv2D, FullyConnected
+from bob.learn.tensorflow.network import Embedding
 from .test_cnn import dummy_experiment
+from .test_cnn_scratch import validate_network
+
 import tensorflow as tf
 import shutil
@@ -23,46 +25,38 @@ Some unit tests that create networks on the fly and load variables
 batch_size = 16
 validation_batch_size = 400
-iterations = 50
+iterations = 300
 seed = 10

-def scratch_network():
+def scratch_network(input_pl):
     # Creating a random network
-    scratch = SequenceNetwork(default_feature_layer="fc1")
-    scratch.add(Conv2D(name="conv1", kernel_size=3,
-                       filters=10,
-                       activation=tf.nn.tanh,
-                       batch_norm=False))
-    scratch.add(FullyConnected(name="fc1", output_dim=10,
-                               activation=None,
-                               batch_norm=False
-                               ))
-
-    return scratch
+    slim = tf.contrib.slim
+    initializer = tf.contrib.layers.xavier_initializer(uniform=False, dtype=tf.float32, seed=10)

-def validate_network(validation_data, validation_labels, network):
-    # Testing
-    validation_data_shuffler = Memory(validation_data, validation_labels,
-                                      input_shape=[28, 28, 1],
-                                      batch_size=validation_batch_size)
-
-    [data, labels] = validation_data_shuffler.get_batch()
-    predictions = network.predict(data)
-    accuracy = 100. * numpy.sum(predictions == labels) / predictions.shape[0]
+    scratch = slim.conv2d(input_pl, 10, 3, activation_fn=tf.nn.tanh,
+                          stride=1,
+                          weights_initializer=initializer,
+                          scope='conv1')
+    scratch = slim.flatten(scratch, scope='flatten1')
+    scratch = slim.fully_connected(scratch, 10,
+                                   weights_initializer=initializer,
+                                   activation_fn=None,
+                                   scope='fc1')

-    return accuracy
+    return scratch

 def test_cnn_pretrained():
+    # Preparing input data
     train_data, train_labels, validation_data, validation_labels = load_mnist()
     train_data = numpy.reshape(train_data, (train_data.shape[0], 28, 28, 1))

     # Creating datashufflers
     data_augmentation = ImageAugmentation()
     train_data_shuffler = Memory(train_data, train_labels,
-                                 input_shape=[28, 28, 1],
+                                 input_shape=[None, 28, 28, 1],
                                  batch_size=batch_size,
                                  data_augmentation=data_augmentation)
     validation_data = numpy.reshape(validation_data, (validation_data.shape[0], 28, 28, 1))
@@ -71,51 +65,55 @@
     directory2 = "./temp/cnn2"

     # Creating a random network
-    scratch = scratch_network()
+    input_pl = train_data_shuffler("data", from_queue=True)
+    graph = scratch_network(input_pl)
+    embedding = Embedding(train_data_shuffler("data", from_queue=False), graph)

     # Loss for the softmax
     loss = BaseLoss(tf.nn.sparse_softmax_cross_entropy_with_logits, tf.reduce_mean)

     # One graph trainer
-    trainer = Trainer(architecture=scratch,
-                      loss=loss,
+    trainer = Trainer(train_data_shuffler,
                       iterations=iterations,
                       analizer=None,
-                      prefetch=False,
-                      learning_rate=constant(0.05, name="regular_lr"),
-                      optimizer=tf.train.AdamOptimizer(name="adam_pretrained_model"),
                       temp_dir=directory
                       )
-
-    trainer.train(train_data_shuffler)
-    accuracy = validate_network(validation_data, validation_labels, scratch)
-    assert accuracy > 85
-
-    del scratch
+    trainer.create_network_from_scratch(graph=graph,
+                                        loss=loss,
+                                        learning_rate=constant(0.01, name="regular_lr"),
+                                        optimizer=tf.train.GradientDescentOptimizer(0.01),
+                                        )
+    trainer.train()
+    accuracy = validate_network(embedding, validation_data, validation_labels)
+    assert accuracy > 80
+
+    del graph
     del loss
     del trainer

     # Training the network using a pre trained model
     loss = BaseLoss(tf.nn.sparse_softmax_cross_entropy_with_logits, tf.reduce_mean, name="loss")
-    scratch = scratch_network()
-    trainer = Trainer(architecture=scratch,
-                      loss=loss,
-                      iterations=iterations + 200,
+    graph = scratch_network(input_pl)
+
+    trainer = Trainer(train_data_shuffler,
+                      iterations=iterations,
                       analizer=None,
-                      prefetch=False,
-                      learning_rate=None,
-                      temp_dir=directory2,
-                      model_from_file=os.path.join(directory, "model.ckp")
+                      temp_dir=directory2
                       )
+    trainer.create_network_from_file(os.path.join(directory, "model.ckp"))

-    trainer.train(train_data_shuffler)
+    trainer.train()

-    accuracy = validate_network(validation_data, validation_labels, scratch)
+    accuracy = validate_network(embedding, validation_data, validation_labels)
     assert accuracy > 90

     shutil.rmtree(directory)
     shutil.rmtree(directory2)

-    del scratch
+    del graph
     del loss
     del trainer
diff --git a/bob/learn/tensorflow/test/test_cnn_scratch.py b/bob/learn/tensorflow/test/test_cnn_scratch.py
index 2fa05da7..28e98ad7 100644
--- a/bob/learn/tensorflow/test/test_cnn_scratch.py
+++ b/bob/learn/tensorflow/test/test_cnn_scratch.py
@@ -48,7 +48,7 @@ def scratch_network():
 def validate_network(embedding, validation_data, validation_labels):
     # Testing
     validation_data_shuffler = Memory(validation_data, validation_labels,
-                                      input_shape=[28, 28, 1],
+                                      input_shape=[None, 28, 28, 1],
                                       batch_size=validation_batch_size,
                                       normalizer=ScaleFactor())
diff --git a/bob/learn/tensorflow/trainers/Trainer.py b/bob/learn/tensorflow/trainers/Trainer.py
index 8e1ca9a9..4c3191c5 100644
--- a/bob/learn/tensorflow/trainers/Trainer.py
+++ b/bob/learn/tensorflow/trainers/Trainer.py
@@ -69,119 +69,138 @@ class Trainer(object):
     """

     def __init__(self,
-                 inputs,
-                 graph,
-                 optimizer=tf.train.AdamOptimizer(),
-                 use_gpu=False,
-                 loss=None,
-                 temp_dir="cnn",
-
-                 # Learning rate
-                 learning_rate=None,
+                 train_data_shuffler,

                  ###### training options ##########
-                 convergence_threshold=0.01,
                  iterations=5000,
                  snapshot=500,
                  validation_snapshot=100,
-                 prefetch=False,

                  ## Analizer
                  analizer=SoftmaxAnalizer(),

-                 ### Pretrained model
-                 model_from_file="",
+                 # Temporary dir
+                 temp_dir="cnn",

                  verbosity_level=2):

-        self.inputs = inputs
-        self.graph = graph
-        self.loss = loss
-
-        if not isinstance(self.graph, tf.Tensor):
-            raise ValueError("Expected a tf.Tensor as input for the keywork `graph`")
-
-        self.predictor = self.loss(self.graph, inputs['label'])
-
-        self.optimizer_class = optimizer
-        self.use_gpu = use_gpu
+        self.train_data_shuffler = train_data_shuffler
         self.temp_dir = temp_dir

-        if learning_rate is None and model_from_file == "":
-            self.learning_rate = constant()
-        else:
-            self.learning_rate = learning_rate
-
         self.iterations = iterations
         self.snapshot = snapshot
         self.validation_snapshot = validation_snapshot
-        self.convergence_threshold = convergence_threshold
-        self.prefetch = prefetch

         # Training variables used in the fit
-        self.optimizer = None
-        self.training_graph = None
-        self.train_data_shuffler = None
         self.summaries_train = None
         self.train_summary_writter = None
         self.thread_pool = None

         # Validation data
-        self.validation_graph = None
         self.validation_summary_writter = None

         # Analizer
         self.analizer = analizer
-
-        self.thread_pool = None
-        self.enqueue_op = None
         self.global_step = None

-        self.model_from_file = model_from_file
         self.session = None

+        self.graph = None
+        self.loss = None
+        self.predictor = None
+        self.optimizer_class = None
+        self.learning_rate = None
+        # Training variables used in the fit
+        self.optimizer = None
+        self.data_ph = None
+        self.label_ph = None
+        self.saver = None
+
         bob.core.log.set_verbosity_level(logger, verbosity_level)

-    def __del__(self):
-        tf.reset_default_graph()
+        # Creating the session
+        self.session = Session.instance(new=True).session
+        self.from_scratch = True

-    """
-    def compute_graph(self, data_shuffler, prefetch=False, name="", training=True):
-        Computes the graph for the trainer.
+    def create_network_from_scratch(self,
+                                    graph,
+                                    optimizer=tf.train.AdamOptimizer(),
+                                    loss=None,

-        ** Parameters **
+                                    # Learning rate
+                                    learning_rate=None,
+                                    ):
+
+        self.data_ph = self.train_data_shuffler("data")
+        self.label_ph = self.train_data_shuffler("label")
+        self.graph = graph
+        self.loss = loss
+        self.predictor = self.loss(self.graph, self.train_data_shuffler("label", from_queue=True))

-        data_shuffler: Data shuffler
-        prefetch: Uses prefetch
-        name: Name of the graph
-        training: Is it a training graph?
+        self.optimizer_class = optimizer
+        self.learning_rate = learning_rate

-        # Defining place holders
-        if prefetch:
-            [placeholder_data, placeholder_labels] = data_shuffler.get_placeholders_forprefetch(name=name)
+        # TODO: find an elegant way to provide this as a parameter of the trainer
+        self.global_step = tf.Variable(0, trainable=False, name="global_step")
+        tf.add_to_collection("global_step", self.global_step)

-            # Defining a placeholder queue for prefetching
-            queue = tf.FIFOQueue(capacity=10,
-                                 dtypes=[tf.float32, tf.int64],
-                                 shapes=[placeholder_data.get_shape().as_list()[1:], []])
+        tf.add_to_collection("graph", self.graph)
+        tf.add_to_collection("predictor", self.predictor)

-            # Fetching the place holders from the queue
-            self.enqueue_op = queue.enqueue_many([placeholder_data, placeholder_labels])
-            feature_batch, label_batch = queue.dequeue_many(data_shuffler.batch_size)
+        tf.add_to_collection("data_ph", self.data_ph)
+        tf.add_to_collection("label_ph", self.label_ph)

-            # Creating the architecture for train and validation
-            if not isinstance(self.architecture, SequenceNetwork):
-                raise ValueError("The variable `architecture` must be an instance of "
-                                 "`bob.learn.tensorflow.network.SequenceNetwork`")
-        else:
-            [feature_batch, label_batch] = data_shuffler.get_placeholders(name=name)
+        # Preparing the optimizer
+        self.optimizer_class._learning_rate = self.learning_rate
+        self.optimizer = self.optimizer_class.minimize(self.predictor, global_step=self.global_step)
+        tf.add_to_collection("optimizer", self.optimizer)
+        tf.add_to_collection("learning_rate", self.learning_rate)

-        # Creating graphs and defining the loss
-        network_graph = self.architecture.compute_graph(feature_batch, training=training)
-        graph = self.loss(network_graph, label_batch)
+        self.summaries_train = self.create_general_summary()
+        tf.add_to_collection("summaries_train", self.summaries_train)

-        return graph
-    """
+        # Creating the saver only AFTER the graph, the global step and the
+        # optimizer exist, so that every variable (including the optimizer
+        # slots) ends up in the checkpoint
+        self.saver = tf.train.Saver(var_list=tf.global_variables())
+
+        # Creating the variables
+        tf.global_variables_initializer().run(session=self.session)
+
+    def create_network_from_file(self, model_from_file):
+        """
+        Bootstrap the trainer from a stored model
+
+        ** Parameters **
+
+        model_from_file: Path of the model to be loaded
+        """
+        self.saver = tf.train.import_meta_graph(model_from_file + ".meta")
+        self.saver.restore(self.session, model_from_file)
+
+        # Loading the training graph
+        self.data_ph = tf.get_collection("data_ph")[0]
+        self.label_ph = tf.get_collection("label_ph")[0]
+
+        self.graph = tf.get_collection("graph")[0]
+        self.predictor = tf.get_collection("predictor")[0]
+
+        # Loading other elements
+        self.optimizer = tf.get_collection("optimizer")[0]
+        self.learning_rate = tf.get_collection("learning_rate")[0]
+        self.summaries_train = tf.get_collection("summaries_train")[0]
+        self.global_step = tf.get_collection("global_step")[0]
+        self.from_scratch = False
+
+        # Note: the variables already hold the restored values; running the
+        # global initializer here would overwrite them
+
+    def __del__(self):
+        tf.reset_default_graph()

     def get_feed_dict(self, data_shuffler):
         """
@@ -193,8 +212,8 @@ class Trainer(object):
         """
         [data, labels] = data_shuffler.get_batch()
-        feed_dict = {self.inputs['data']: data,
-                     self.inputs['label']: labels}
+        feed_dict = {self.data_ph: data,
+                     self.label_ph: labels}
         return feed_dict

     def fit(self, step):
@@ -207,7 +226,7 @@ class Trainer(object):

         """

-        if self.prefetch:
+        if self.train_data_shuffler.prefetch:
             _, l, lr, summary = self.session.run([self.optimizer, self.predictor,
                                                   self.learning_rate, self.summaries_train])
         else:
@@ -218,6 +237,27 @@ class Trainer(object):
         logger.info("Loss training set step={0} = {1}".format(step, l))
         self.train_summary_writter.add_summary(summary, step)

+    def compute_validation(self, data_shuffler, step):
+        """
+        Computes the loss in the validation set
+
+        ** Parameters **
+        data_shuffler: The data shuffler to be used
+        step: Iteration number
+
+        """
+        pass
+        # Opening a new session for validation
+        #feed_dict = self.get_feed_dict(data_shuffler)
+        #l, summary = self.session.run(self.predictor, self.summaries_train, feed_dict=feed_dict)
+        #train_summary_writter.add_summary(summary, step)
+
+        #summaries = [summary_pb2.Summary.Value(tag="loss", simple_value=float(l))]
+        #self.validation_summary_writter.add_summary(summary_pb2.Summary(value=summaries), step)
+        #logger.info("Loss VALIDATION set step={0} = {1}".format(step, l))
+
     def create_general_summary(self):
         """
         Creates a simple tensorboard summary with the value of the loss and learning rate
@@ -228,14 +268,13 @@ class Trainer(object):
         tf.summary.scalar('lr', self.learning_rate)
         return tf.summary.merge_all()

-    """
     def start_thread(self):
-
+        """
         Start pool of threads for pre-fetching

         **Parameters**
         session: Tensorflow session
-
+        """
         threads = []
         for n in range(3):
             t = threading.Thread(target=self.load_and_enqueue, args=())
             t.daemon = True  # thread will close when parent quits
             t.start()
             threads.append(t)
         return threads

     def load_and_enqueue(self):
-
+        """
         Injecting data in the place holder queue

         **Parameters**
         session: Tensorflow session
-
+        """
         while not self.thread_pool.should_stop():
             [train_data, train_labels] = self.train_data_shuffler.get_batch()
-            [train_placeholder_data, train_placeholder_labels] = self.train_data_shuffler.get_placeholders()

-            feed_dict = {train_placeholder_data: train_data,
-                         train_placeholder_labels: train_labels}
+            feed_dict = {self.data_ph: train_data,
+                         self.label_ph: train_labels}

-            self.session.run(self.enqueue_op, feed_dict=feed_dict)
+            self.session.run(self.train_data_shuffler.enqueue_op, feed_dict=feed_dict)

-    """
-
-    def bootstrap_graphs_fromfile(self, train_data_shuffler, validation_data_shuffler):
-        """
-        Bootstrap all the necessary data from file
-
-        ** Parameters **
-        session: Tensorflow session
-        train_data_shuffler: Data shuffler for training
-        validation_data_shuffler: Data shuffler for validation
-
-
-        """
-        saver = self.architecture.load(self.model_from_file, clear_devices=False)
-
-        # Loading training graph
-        self.training_graph = tf.get_collection("training_graph")[0]
-
-        # Loding other elements
-        self.optimizer = tf.get_collection("optimizer")[0]
-        self.learning_rate = tf.get_collection("learning_rate")[0]
-        self.summaries_train = tf.get_collection("summaries_train")[0]
-        self.global_step = tf.get_collection("global_step")[0]
-
-        if validation_data_shuffler is not None:
-            self.validation_graph = tf.get_collection("validation_graph")[0]
-
-        self.bootstrap_placeholders_fromfile(train_data_shuffler, validation_data_shuffler)
-
-        return saver
-
-    def compute_validation(self, data_shuffler, step):
-        """
-        Computes the loss in the validation set
-
-        ** Parameters **
-        session: Tensorflow session
-        data_shuffler: The data shuffler to be used
-        step: Iteration number
-
-        """
-        # Opening a new session for validation
-        feed_dict = self.get_feed_dict(data_shuffler)
-        l = self.session.run(self.predictor, feed_dict=feed_dict)
-
-        summaries = [summary_pb2.Summary.Value(tag="loss", simple_value=float(l))]
-        self.validation_summary_writter.add_summary(summary_pb2.Summary(value=summaries), step)
-        logger.info("Loss VALIDATION set step={0} = {1}".format(step, l))
-
-    def train(self, train_data_shuffler, validation_data_shuffler=None):
+    def train(self, validation_data_shuffler=None):
         """
         Train the network:
@@ -323,56 +312,32 @@ class Trainer(object):

         # Creating directory
         bob.io.base.create_directories_safe(self.temp_dir)
-        self.train_data_shuffler = train_data_shuffler

         logger.info("Initializing !!")
-        self.session = Session.instance(new=True).session

         # Loading a pretrained model
-        if self.model_from_file != "":
-            logger.info("Loading pretrained model from {0}".format(self.model_from_file))
-            saver = self.bootstrap_graphs_fromfile(train_data_shuffler, validation_data_shuffler)
-
-            start_step = self.global_step.eval(session=self.session)
-
-        else:
+        if self.from_scratch:
             start_step = 0
-
-            # TODO: find an elegant way to provide this as a parameter of the trainer
-            self.global_step = tf.Variable(0, trainable=False, name="global_step")
-            tf.add_to_collection("global_step", self.global_step)
-
-            # Preparing the optimizer
-            self.optimizer_class._learning_rate = self.learning_rate
-            self.optimizer = self.optimizer_class.minimize(self.predictor, global_step=self.global_step)
-            tf.add_to_collection("optimizer", self.optimizer)
-            tf.add_to_collection("learning_rate", self.learning_rate)
-
-            self.summaries_train = self.create_general_summary()
-            tf.add_to_collection("summaries_train", self.summaries_train)
-
-            # Train summary
-            tf.global_variables_initializer().run(session=self.session)
-
-            # Original tensorflow saver object
-            saver = tf.train.Saver(var_list=tf.global_variables())
+        else:
+            start_step = self.global_step.eval(session=self.session)

         #if isinstance(train_data_shuffler, OnlineSampling):
         #    train_data_shuffler.set_feature_extractor(self.architecture, session=self.session)

         # Start a thread to enqueue data asynchronously, and hide I/O latency.
-        #if self.prefetch:
-        #    self.thread_pool = tf.train.Coordinator()
-        #    tf.train.start_queue_runners(coord=self.thread_pool, sess=self.session)
-        #    threads = self.start_thread()
+        if self.train_data_shuffler.prefetch:
+            self.thread_pool = tf.train.Coordinator()
+            tf.train.start_queue_runners(coord=self.thread_pool, sess=self.session)
+            threads = self.start_thread()

         # TENSOR BOARD SUMMARY
         self.train_summary_writter = tf.summary.FileWriter(os.path.join(self.temp_dir, 'train'), self.session.graph)
         if validation_data_shuffler is not None:
             self.validation_summary_writter = tf.summary.FileWriter(os.path.join(self.temp_dir, 'validation'), self.session.graph)
-
+
+        # Training loop
         for step in range(start_step, self.iterations):
+            # Run fit in the graph
             start = time.time()
             self.fit(step)
             end = time.time()
@@ -391,7 +356,7 @@ class Trainer(object):
         if step % self.snapshot == 0:
             logger.info("Taking snapshot")
             path = os.path.join(self.temp_dir, 'model_snapshot{0}.ckp'.format(step))
-            saver.save(self.session, path)
+            self.saver.save(self.session, path)
             #self.architecture.save(saver, path)

         logger.info("Training finally finished")
@@ -402,9 +367,9 @@ class Trainer(object):

         # Saving the final network
         path = os.path.join(self.temp_dir, 'model.ckp')
-        saver.save(self.session, path)
+        self.saver.save(self.session, path)

-        if self.prefetch:
+        if self.train_data_shuffler.prefetch:
             # now they should definetely stop
             self.thread_pool.request_stop()
             #self.thread_pool.join(threads)
--
GitLab
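Note for reviewers: after this patch the datashuffler owns the input placeholders (and, when prefetch=True, the tf.FIFOQueue), and the Trainer is built in two steps: construct it with a datashuffler, then bind the graph, loss and optimizer via create_network_from_scratch (or bootstrap from a checkpoint via create_network_from_file). The sketch below is distilled from the updated tests above; the hyper-parameters and temp paths are illustrative, not prescriptive:

    import numpy
    import tensorflow as tf

    from bob.learn.tensorflow.datashuffler import Memory, ScaleFactor
    from bob.learn.tensorflow.network import Chopra
    from bob.learn.tensorflow.loss import BaseLoss
    from bob.learn.tensorflow.trainers import Trainer, constant
    from bob.learn.tensorflow.utils import load_mnist

    # 1. The datashuffler owns the placeholders; set prefetch=True to feed
    #    the graph through the FIFO queue introduced in Base.py
    train_data, train_labels, _, _ = load_mnist()
    train_data = numpy.reshape(train_data, (train_data.shape[0], 28, 28, 1))
    train_data_shuffler = Memory(train_data, train_labels,
                                 input_shape=[None, 28, 28, 1],
                                 batch_size=32,
                                 normalizer=ScaleFactor(),
                                 prefetch=False)

    # 2. Build the graph on the placeholder handed out by the shuffler;
    #    from_queue=True returns the dequeue endpoint when prefetch is on
    architecture = Chopra(seed=10, fc1_output=10)
    graph = architecture(train_data_shuffler("data", from_queue=True))

    # 3. Two-step trainer: construct, then bind graph/loss/optimizer
    loss = BaseLoss(tf.nn.sparse_softmax_cross_entropy_with_logits, tf.reduce_mean)
    trainer = Trainer(train_data_shuffler, iterations=300, temp_dir="./temp/cnn")
    trainer.create_network_from_scratch(graph=graph,
                                        loss=loss,
                                        learning_rate=constant(0.01, name="regular_lr"),
                                        optimizer=tf.train.GradientDescentOptimizer(0.01))
    trainer.train()

    # To resume from the checkpoint written above, bootstrap a fresh trainer
    # from the stored model instead of rebuilding the graph:
    # trainer2 = Trainer(train_data_shuffler, iterations=300, temp_dir="./temp/cnn2")
    # trainer2.create_network_from_file("./temp/cnn/model.ckp")
    # trainer2.train()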