diff --git a/bob/learn/tensorflow/datashuffler/Base.py b/bob/learn/tensorflow/datashuffler/Base.py
index 47d18f005609d2d7f043e31f73db91554ddc7f2e..a4aab2cbb2ad48f44dbb17d087abbfc8fc34940e 100644
--- a/bob/learn/tensorflow/datashuffler/Base.py
+++ b/bob/learn/tensorflow/datashuffler/Base.py
@@ -39,16 +39,26 @@ class Base(object):
 
      normalizer:
        The algorithm used for feature scaling. Look :py:class:`bob.learn.tensorflow.datashuffler.ScaleFactor`, :py:class:`bob.learn.tensorflow.datashuffler.Linear` and :py:class:`bob.learn.tensorflow.datashuffler.MeanOffset`
+       
+     prefetch:
+        Do prefetch?
+        
+     prefetch_capacity:
+        
 
     """
 
     def __init__(self, data, labels,
-                 input_shape,
+                 input_shape=[None, 28, 28, 1],
                  input_dtype="float64",
-                 batch_size=1,
+                 batch_size=32,
                  seed=10,
                  data_augmentation=None,
-                 normalizer=Linear()):
+                 normalizer=Linear(),
+                 prefetch=False,
+                 prefetch_capacity=10):
+
+        # Setting the seed for the pseudo random number generator
         self.seed = seed
         numpy.random.seed(seed)
 
@@ -58,10 +68,9 @@ class Base(object):
         # TODO: Check if the bacth size is higher than the input data
         self.batch_size = batch_size
 
+        # Preparing the inputs
         self.data = data
-        self.shape = tuple([batch_size] + input_shape)
         self.input_shape = tuple(input_shape)
-
         self.labels = labels
         self.possible_labels = list(set(self.labels))
 
@@ -72,43 +81,72 @@ class Base(object):
         self.indexes = numpy.array(range(self.n_samples))
         numpy.random.shuffle(self.indexes)
 
-        self.data_placeholder = None
-        self.label_placeholder = None
-
+        # Use data data augmentation?
         self.data_augmentation = data_augmentation
-        self.deployment_shape = [None] + list(input_shape)
 
-    def set_placeholders(self, data, label):
-        self.data_placeholder = data
-        self.label_placeholder = label
+        # Preparing placeholders
+        self.data_ph = None
+        self.label_ph = None
+        # Prefetch variables
+        self.prefetch = prefetch
+        self.data_ph_from_queue = None
+        self.label_ph_from_queue = None
 
-    def get_batch(self):
+    def create_placeholders(self):
         """
-        Shuffle dataset and get a random batch.
+        Create place holder instances
+        
+        :return: 
         """
-        raise NotImplementedError("Method not implemented in this level. You should use one of the derived classes.")
+        with tf.name_scope("Input"):
+
+            self.data_ph = tf.placeholder(tf.float32, shape=self.input_shape, name="data")
+            self.label_ph = tf.placeholder(tf.int64, shape=[None], name="label")
+
+            # If prefetch, setup the queue to feed data
+            if self.prefetch:
+                queue = tf.FIFOQueue(capacity=self.prefetch_capacity,
+                                     dtypes=[tf.float32, tf.int64],
+                                     shapes=[self.input_shape[1:], []])
 
-    def get_placeholders(self, name=""):
+                # Fetching the place holders from the queue
+                self.enqueue_op = queue.enqueue_many([self.data_ph, self.label_ph])
+                self.data_ph_from_queue, self.label_ph_from_queue = queue.dequeue_many(self.batch_size)
+
+            else:
+                self.data_ph_from_queue = self.data_ph
+                self.label_ph_from_queue = self.label_ph
+
+    def __call__(self, element, from_queue=False):
         """
-        Returns a place holder with the size of your batch
+        Return the necessary placeholder
+        
         """
 
-        if self.data_placeholder is None:
-            self.data_placeholder = tf.placeholder(tf.float32, shape=self.shape, name=name)
+        if not element in ["data", "label"]:
+            raise ValueError("Value '{0}' invalid. Options available are {1}".format(element, self.placeholder_options))
+
+        # If None, create the placeholders from scratch
+        if self.data_ph is None:
+            self.create_placeholders()
 
-        if self.label_placeholder is None:
-            self.label_placeholder = tf.placeholder(tf.int64, shape=self.shape[0])
+        if element == "data":
+            if from_queue:
+                return self.data_ph_from_queue
+            else:
+                return self.data_ph
 
-        return [self.data_placeholder, self.label_placeholder]
+        else:
+            if from_queue:
+                return self.label_ph_from_queue
+            else:
+                return self.label_ph
 
-    def get_placeholders_forprefetch(self, name=""):
+    def get_batch(self):
         """
-        Returns a place holder with the size of your batch
+        Shuffle dataset and get a random batch.
         """
-        if self.data_placeholder is None:
-            self.data_placeholder = tf.placeholder(tf.float32, shape=tuple([None] + list(self.shape[1:])), name=name)
-            self.label_placeholder = tf.placeholder(tf.int64, shape=[None, ])
-        return [self.data_placeholder, self.label_placeholder]
+        raise NotImplementedError("Method not implemented in this level. You should use one of the derived classes.")
 
     def bob2skimage(self, bob_image):
         """
@@ -167,10 +205,6 @@ class Base(object):
         else:
             return data
 
-    def reshape_for_deploy(self, data):
-        shape = tuple([1] + list(data.shape))
-        return numpy.reshape(data, shape)
-
     def normalize_sample(self, x):
         """
         Normalize the sample.
diff --git a/bob/learn/tensorflow/network/__init__.py b/bob/learn/tensorflow/network/__init__.py
index e83fb345e7ac8b55e068ac3a9904997a77de1ca1..0a88bc3c44e84966829f6ee4dedc85c5f63961cd 100644
--- a/bob/learn/tensorflow/network/__init__.py
+++ b/bob/learn/tensorflow/network/__init__.py
@@ -10,6 +10,7 @@ from .VGG16 import VGG16
 from .VGG16_mod import VGG16_mod
 from .SimpleAudio import SimpleAudio
 from .Embedding import Embedding
+#from .Input import Input
 
 # gets sphinx autodoc done right - don't remove it
 def __appropriate__(*args):
diff --git a/bob/learn/tensorflow/test/test_cnn.py b/bob/learn/tensorflow/test/test_cnn.py
index 59e6223828ada1ee2f1e693d034c0602e06919a9..4122a547ccc353172db50d3a6574b42847aebe43 100644
--- a/bob/learn/tensorflow/test/test_cnn.py
+++ b/bob/learn/tensorflow/test/test_cnn.py
@@ -5,7 +5,7 @@
 
 import numpy
 from bob.learn.tensorflow.datashuffler import Memory, SiameseMemory, TripletMemory, ImageAugmentation, ScaleFactor
-from bob.learn.tensorflow.network import Chopra, SequenceNetwork
+from bob.learn.tensorflow.network import Chopra
 from bob.learn.tensorflow.loss import BaseLoss, ContrastiveLoss, TripletLoss
 from bob.learn.tensorflow.trainers import Trainer, SiameseTrainer, TripletTrainer, constant
 from .test_cnn_scratch import validate_network
@@ -23,7 +23,7 @@ Some unit tests for the datashuffler
 """
 
 batch_size = 32
-validation_batch_size = 400
+validation_batch_size = 32
 iterations = 300
 seed = 10
 
@@ -77,6 +77,7 @@ def dummy_experiment(data_s, architecture):
 
 def test_cnn_trainer():
 
+    # Loading data
     train_data, train_labels, validation_data, validation_labels = load_mnist()
     train_data = numpy.reshape(train_data, (train_data.shape[0], 28, 28, 1))
     validation_data = numpy.reshape(validation_data, (validation_data.shape[0], 28, 28, 1))
@@ -84,40 +85,45 @@ def test_cnn_trainer():
     # Creating datashufflers
     data_augmentation = ImageAugmentation()
     train_data_shuffler = Memory(train_data, train_labels,
-                                 input_shape=[28, 28, 1],
+                                 input_shape=[None, 28, 28, 1],
                                  batch_size=batch_size,
                                  data_augmentation=data_augmentation,
                                  normalizer=ScaleFactor())
 
+    validation_data_shuffler = Memory(validation_data, validation_labels,
+                                      input_shape=[None, 28, 28, 1],
+                                      batch_size=batch_size,
+                                      data_augmentation=data_augmentation,
+                                      normalizer=ScaleFactor())
+
     directory = "./temp/cnn"
 
     # Loss for the softmax
     loss = BaseLoss(tf.nn.sparse_softmax_cross_entropy_with_logits, tf.reduce_mean)
 
-    inputs = {}
-    inputs['data'] = tf.placeholder(tf.float32, shape=[None, 28, 28, 1], name="train_data")
-    inputs['label'] = tf.placeholder(tf.int64, shape=[None], name="train_label")
-
     # Preparing the architecture
     architecture = Chopra(seed=seed,
                           fc1_output=10)
-    graph = architecture(inputs['data'])
-    embedding = Embedding(inputs['data'], graph)
+    input_pl = train_data_shuffler("data", from_queue=True)
+    graph = architecture(input_pl)
+    embedding = Embedding(train_data_shuffler("data", from_queue=False), graph)
 
     # One graph trainer
-    trainer = Trainer(inputs=inputs,
-                      graph=graph,
-                      loss=loss,
+    trainer = Trainer(train_data_shuffler,
                       iterations=iterations,
                       analizer=None,
-                      prefetch=False,
-                      learning_rate=constant(0.01, name="regular_lr"),
-                      optimizer=tf.train.GradientDescentOptimizer(0.01),
                       temp_dir=directory
                       )
-    trainer.train(train_data_shuffler)
+    trainer.create_network_from_scratch(graph=graph,
+                                        loss=loss,
+                                        learning_rate=constant(0.01, name="regular_lr"),
+                                        optimizer=tf.train.GradientDescentOptimizer(0.01),
+                                        )
+    trainer.train()
+    #trainer.train(validation_data_shuffler)
+
+    # Using embedding to compute the accuracy
     accuracy = validate_network(embedding, validation_data, validation_labels)
-
     # At least 80% of accuracy
     assert accuracy > 80.
     shutil.rmtree(directory)
@@ -165,8 +171,6 @@ def test_siamesecnn_trainer():
                              optimizer=tf.train.AdamOptimizer(name="adam_siamese"),
                              temp_dir=directory
                              )
-
-    import ipdb; ipdb.set_trace();
     trainer.train(train_data_shuffler)
 
     eer = dummy_experiment(validation_data_shuffler, architecture)
diff --git a/bob/learn/tensorflow/test/test_cnn_pretrained_model.py b/bob/learn/tensorflow/test/test_cnn_pretrained_model.py
index 62ee45545404f8499ac05338773bb542a899cf06..645e2ae80e82395d4a209e3d15a7d30f091d954b 100644
--- a/bob/learn/tensorflow/test/test_cnn_pretrained_model.py
+++ b/bob/learn/tensorflow/test/test_cnn_pretrained_model.py
@@ -10,9 +10,11 @@ from bob.learn.tensorflow.datashuffler import Memory, ImageAugmentation, Triplet
 from bob.learn.tensorflow.loss import BaseLoss, TripletLoss, ContrastiveLoss
 from bob.learn.tensorflow.trainers import Trainer, constant, TripletTrainer, SiameseTrainer
 from bob.learn.tensorflow.utils import load_mnist
-from bob.learn.tensorflow.network import SequenceNetwork
 from bob.learn.tensorflow.layers import Conv2D, FullyConnected
+from bob.learn.tensorflow.network import Embedding
 from .test_cnn import dummy_experiment
+from .test_cnn_scratch import validate_network
+
 
 import tensorflow as tf
 import shutil
@@ -23,46 +25,38 @@ Some unit tests that create networks on the fly and load variables
 
 batch_size = 16
 validation_batch_size = 400
-iterations = 50
+iterations =300
 seed = 10
 
 
-def scratch_network():
+def scratch_network(input_pl):
     # Creating a random network
-    scratch = SequenceNetwork(default_feature_layer="fc1")
-    scratch.add(Conv2D(name="conv1", kernel_size=3,
-                       filters=10,
-                       activation=tf.nn.tanh,
-                       batch_norm=False))
-    scratch.add(FullyConnected(name="fc1", output_dim=10,
-                               activation=None,
-                               batch_norm=False
-                               ))
-
-    return scratch
+    slim = tf.contrib.slim
 
+    initializer = tf.contrib.layers.xavier_initializer(uniform=False, dtype=tf.float32, seed=10)
 
-def validate_network(validation_data, validation_labels, network):
-    # Testing
-    validation_data_shuffler = Memory(validation_data, validation_labels,
-                                      input_shape=[28, 28, 1],
-                                      batch_size=validation_batch_size)
-
-    [data, labels] = validation_data_shuffler.get_batch()
-    predictions = network.predict(data)
-    accuracy = 100. * numpy.sum(predictions == labels) / predictions.shape[0]
+    scratch = slim.conv2d(input_pl, 10, 3, activation_fn=tf.nn.tanh,
+                          stride=1,
+                          weights_initializer=initializer,
+                          scope='conv1')
+    scratch = slim.flatten(scratch, scope='flatten1')
+    scratch = slim.fully_connected(scratch, 10,
+                                   weights_initializer=initializer,
+                                   activation_fn=None,
+                                   scope='fc1')
 
-    return accuracy
+    return scratch
 
 
 def test_cnn_pretrained():
+    # Preparing input data
     train_data, train_labels, validation_data, validation_labels = load_mnist()
     train_data = numpy.reshape(train_data, (train_data.shape[0], 28, 28, 1))
 
     # Creating datashufflers
     data_augmentation = ImageAugmentation()
     train_data_shuffler = Memory(train_data, train_labels,
-                                 input_shape=[28, 28, 1],
+                                 input_shape=[None, 28, 28, 1],
                                  batch_size=batch_size,
                                  data_augmentation=data_augmentation)
     validation_data = numpy.reshape(validation_data, (validation_data.shape[0], 28, 28, 1))
@@ -71,51 +65,55 @@ def test_cnn_pretrained():
     directory2 = "./temp/cnn2"
 
     # Creating a random network
-    scratch = scratch_network()
+    input_pl = train_data_shuffler("data", from_queue=True)
+    graph = scratch_network(input_pl)
+    embedding = Embedding(train_data_shuffler("data", from_queue=False), graph)
 
     # Loss for the softmax
     loss = BaseLoss(tf.nn.sparse_softmax_cross_entropy_with_logits, tf.reduce_mean)
 
     # One graph trainer
-    trainer = Trainer(architecture=scratch,
-                      loss=loss,
+    # One graph trainer
+    trainer = Trainer(train_data_shuffler,
                       iterations=iterations,
                       analizer=None,
-                      prefetch=False,
-                      learning_rate=constant(0.05, name="regular_lr"),
-                      optimizer=tf.train.AdamOptimizer(name="adam_pretrained_model"),
                       temp_dir=directory
                       )
-
-    trainer.train(train_data_shuffler)
-    accuracy = validate_network(validation_data, validation_labels, scratch)
-    assert accuracy > 85
-
-    del scratch
+    trainer.create_network_from_scratch(graph=graph,
+                                        loss=loss,
+                                        learning_rate=constant(0.01, name="regular_lr"),
+                                        optimizer=tf.train.GradientDescentOptimizer(0.01),
+                                        )
+    trainer.train()
+    accuracy = validate_network(embedding, validation_data, validation_labels)
+    assert accuracy > 80
+
+    del graph
     del loss
     del trainer
-
     # Training the network using a pre trained model
     loss = BaseLoss(tf.nn.sparse_softmax_cross_entropy_with_logits, tf.reduce_mean, name="loss")
-    scratch = scratch_network()
-    trainer = Trainer(architecture=scratch,
-                      loss=loss,
-                      iterations=iterations + 200,
+    graph = scratch_network(input_pl)
+
+    # One graph trainer
+    trainer = Trainer(train_data_shuffler,
+                      iterations=iterations,
                       analizer=None,
-                      prefetch=False,
-                      learning_rate=None,
-                      temp_dir=directory2,
-                      model_from_file=os.path.join(directory, "model.ckp")
+                      temp_dir=directory
                       )
+    trainer.create_network_from_file(os.path.join(directory, "model.ckp"))
 
-    trainer.train(train_data_shuffler)
+    import ipdb;
+    ipdb.set_trace()
 
-    accuracy = validate_network(validation_data, validation_labels, scratch)
+    trainer.train()
+
+    accuracy = validate_network(embedding, validation_data, validation_labels)
     assert accuracy > 90
     shutil.rmtree(directory)
     shutil.rmtree(directory2)
 
-    del scratch
+    del graph
     del loss
     del trainer
 
diff --git a/bob/learn/tensorflow/test/test_cnn_scratch.py b/bob/learn/tensorflow/test/test_cnn_scratch.py
index 2fa05da7da9ae0ba5582852b0542f8aae5607ece..28e98ad78e1d6ae3a1b81fd577b66b1282822f15 100644
--- a/bob/learn/tensorflow/test/test_cnn_scratch.py
+++ b/bob/learn/tensorflow/test/test_cnn_scratch.py
@@ -48,7 +48,7 @@ def scratch_network():
 def validate_network(embedding, validation_data, validation_labels):
     # Testing
     validation_data_shuffler = Memory(validation_data, validation_labels,
-                                      input_shape=[28, 28, 1],
+                                      input_shape=[None, 28, 28, 1],
                                       batch_size=validation_batch_size,
                                       normalizer=ScaleFactor())
 
diff --git a/bob/learn/tensorflow/trainers/Trainer.py b/bob/learn/tensorflow/trainers/Trainer.py
index 8e1ca9a910d871a56269c1bbf610db4169ecbca5..4c3191c59460581babfd70fab0a95e893d21329d 100644
--- a/bob/learn/tensorflow/trainers/Trainer.py
+++ b/bob/learn/tensorflow/trainers/Trainer.py
@@ -69,119 +69,138 @@ class Trainer(object):
     """
 
     def __init__(self,
-                 inputs,
-                 graph,
-                 optimizer=tf.train.AdamOptimizer(),
-                 use_gpu=False,
-                 loss=None,
-                 temp_dir="cnn",
-
-                 # Learning rate
-                 learning_rate=None,
+                 train_data_shuffler,
 
                  ###### training options ##########
-                 convergence_threshold=0.01,
                  iterations=5000,
                  snapshot=500,
                  validation_snapshot=100,
-                 prefetch=False,
 
                  ## Analizer
                  analizer=SoftmaxAnalizer(),
 
-                 ### Pretrained model
-                 model_from_file="",
+                 # Temporatu dir
+                 temp_dir="cnn",
 
                  verbosity_level=2):
 
-        self.inputs = inputs
-        self.graph = graph
-        self.loss = loss
-
-        if not isinstance(self.graph, tf.Tensor):
-            raise ValueError("Expected a tf.Tensor as input for the keywork `graph`")
-
-        self.predictor = self.loss(self.graph, inputs['label'])
-
-        self.optimizer_class = optimizer
-        self.use_gpu = use_gpu
+        self.train_data_shuffler = train_data_shuffler
         self.temp_dir = temp_dir
 
-        if learning_rate is None and model_from_file == "":
-            self.learning_rate = constant()
-        else:
-            self.learning_rate = learning_rate
-
         self.iterations = iterations
         self.snapshot = snapshot
         self.validation_snapshot = validation_snapshot
-        self.convergence_threshold = convergence_threshold
-        self.prefetch = prefetch
 
         # Training variables used in the fit
-        self.optimizer = None
-        self.training_graph = None
-        self.train_data_shuffler = None
         self.summaries_train = None
         self.train_summary_writter = None
         self.thread_pool = None
 
         # Validation data
-        self.validation_graph = None
         self.validation_summary_writter = None
 
         # Analizer
         self.analizer = analizer
-
-        self.thread_pool = None
-        self.enqueue_op = None
         self.global_step = None
 
-        self.model_from_file = model_from_file
         self.session = None
 
+        self.graph = None
+        self.loss = None
+        self.predictor = None
+        self.optimizer_class = None
+        self.learning_rate = None
+        # Training variables used in the fit
+        self.optimizer = None
+        self.data_ph = None
+        self.label_ph = None
+        self.saver = None
+
         bob.core.log.set_verbosity_level(logger, verbosity_level)
 
-    def __del__(self):
-        tf.reset_default_graph()
+        # Creating the session
+        self.session = Session.instance(new=True).session
+        self.from_scratch = True
 
-    """
-    def compute_graph(self, data_shuffler, prefetch=False, name="", training=True):
-        Computes the graph for the trainer.
+    def create_network_from_scratch(self,
+                                    graph,
+                                    optimizer=tf.train.AdamOptimizer(),
+                                    loss=None,
 
-        ** Parameters **
+                                    # Learning rate
+                                    learning_rate=None,
+                                    ):
+
+        self.saver = tf.train.Saver(var_list=tf.global_variables())
+
+        self.data_ph = self.train_data_shuffler("data")
+        self.label_ph = self.train_data_shuffler("label")
+        self.graph = graph
+        self.loss = loss
+        self.predictor = self.loss(self.graph, self.train_data_shuffler("label", from_queue=True))
 
-            data_shuffler: Data shuffler
-            prefetch: Uses prefetch
-            name: Name of the graph
-            training: Is it a training graph?
+        self.optimizer_class = optimizer
+        self.learning_rate = learning_rate
 
-        # Defining place holders
-        if prefetch:
-            [placeholder_data, placeholder_labels] = data_shuffler.get_placeholders_forprefetch(name=name)
+        # TODO: find an elegant way to provide this as a parameter of the trainer
+        self.global_step = tf.Variable(0, trainable=False, name="global_step")
+        tf.add_to_collection("global_step", self.global_step)
 
-            # Defining a placeholder queue for prefetching
-            queue = tf.FIFOQueue(capacity=10,
-                                 dtypes=[tf.float32, tf.int64],
-                                 shapes=[placeholder_data.get_shape().as_list()[1:], []])
+        tf.add_to_collection("graph", self.graph)
+        tf.add_to_collection("predictor", self.predictor)
 
-            # Fetching the place holders from the queue
-            self.enqueue_op = queue.enqueue_many([placeholder_data, placeholder_labels])
-            feature_batch, label_batch = queue.dequeue_many(data_shuffler.batch_size)
+        tf.add_to_collection("data_ph", self.data_ph)
+        tf.add_to_collection("label_ph", self.label_ph)
 
-            # Creating the architecture for train and validation
-            if not isinstance(self.architecture, SequenceNetwork):
-                raise ValueError("The variable `architecture` must be an instance of "
-                                 "`bob.learn.tensorflow.network.SequenceNetwork`")
-        else:
-            [feature_batch, label_batch] = data_shuffler.get_placeholders(name=name)
+        # Preparing the optimizer
+        self.optimizer_class._learning_rate = self.learning_rate
+        self.optimizer = self.optimizer_class.minimize(self.predictor, global_step=self.global_step)
+        tf.add_to_collection("optimizer", self.optimizer)
+        tf.add_to_collection("learning_rate", self.learning_rate)
 
-        # Creating graphs and defining the loss
-        network_graph = self.architecture.compute_graph(feature_batch, training=training)
-        graph = self.loss(network_graph, label_batch)
+        self.summaries_train = self.create_general_summary()
+        tf.add_to_collection("summaries_train", self.summaries_train)
 
-        return graph
-    """
+        # Creating the variables
+        tf.global_variables_initializer().run(session=self.session)
+
+    def create_network_from_file(self, model_from_file):
+        """
+        Bootstrap all the necessary data from file
+
+         ** Parameters **
+           session: Tensorflow session
+           train_data_shuffler: Data shuffler for training
+           validation_data_shuffler: Data shuffler for validation
+
+
+        """
+        #saver = self.architecture.load(self.model_from_file, clear_devices=False)
+        self.saver = tf.train.import_meta_graph(model_from_file + ".meta")
+        self.saver.restore(self.session, model_from_file)
+
+        # Loading training graph
+        self.data_ph = tf.get_collection("data_ph")
+        self.label_ph = tf.get_collection("label_ph")
+
+        self.graph = tf.get_collection("graph")[0]
+        self.predictor = tf.get_collection("predictor")[0]
+
+        # Loding other elements
+        self.optimizer = tf.get_collection("optimizer")[0]
+        self.learning_rate = tf.get_collection("learning_rate")[0]
+        self.summaries_train = tf.get_collection("summaries_train")[0]
+        self.global_step = tf.get_collection("global_step")[0]
+        self.from_scratch = False
+
+        # Creating the variables
+        tf.global_variables_initializer().run(session=self.session)
+        import ipdb; ipdb.set_trace()
+        x=0
+
+
+    def __del__(self):
+        tf.reset_default_graph()
 
     def get_feed_dict(self, data_shuffler):
         """
@@ -193,8 +212,8 @@ class Trainer(object):
         """
         [data, labels] = data_shuffler.get_batch()
 
-        feed_dict = {self.inputs['data']: data,
-                     self.inputs['label']: labels}
+        feed_dict = {self.data_ph: data,
+                     self.label_ph: labels}
         return feed_dict
 
     def fit(self, step):
@@ -207,7 +226,7 @@ class Trainer(object):
 
         """
 
-        if self.prefetch:
+        if self.train_data_shuffler.prefetch:
             _, l, lr, summary = self.session.run([self.optimizer, self.predictor,
                                                   self.learning_rate, self.summaries_train])
         else:
@@ -218,6 +237,27 @@ class Trainer(object):
         logger.info("Loss training set step={0} = {1}".format(step, l))
         self.train_summary_writter.add_summary(summary, step)
 
+    def compute_validation(self, data_shuffler, step):
+        """
+        Computes the loss in the validation set
+
+        ** Parameters **
+            session: Tensorflow session
+            data_shuffler: The data shuffler to be used
+            step: Iteration number
+
+        """
+        pass
+        # Opening a new session for validation
+        #feed_dict = self.get_feed_dict(data_shuffler)
+        #l, summary = self.session.run(self.predictor, self.summaries_train, feed_dict=feed_dict)
+        #train_summary_writter.add_summary(summary, step)
+
+
+        #summaries = [summary_pb2.Summary.Value(tag="loss", simple_value=float(l))]
+        #self.validation_summary_writter.add_summary(summary_pb2.Summary(value=summaries), step)
+        #logger.info("Loss VALIDATION set step={0} = {1}".format(step, l))
+
     def create_general_summary(self):
         """
         Creates a simple tensorboard summary with the value of the loss and learning rate
@@ -228,14 +268,13 @@ class Trainer(object):
         tf.summary.scalar('lr', self.learning_rate)
         return tf.summary.merge_all()
 
-    """
     def start_thread(self):
-
+        """
         Start pool of threads for pre-fetching
 
         **Parameters**
           session: Tensorflow session
-
+        """
 
         threads = []
         for n in range(3):
@@ -246,72 +285,22 @@ class Trainer(object):
         return threads
 
     def load_and_enqueue(self):
-
+        """
         Injecting data in the place holder queue
 
         **Parameters**
           session: Tensorflow session
 
-
+        """
         while not self.thread_pool.should_stop():
             [train_data, train_labels] = self.train_data_shuffler.get_batch()
-            [train_placeholder_data, train_placeholder_labels] = self.train_data_shuffler.get_placeholders()
 
-            feed_dict = {train_placeholder_data: train_data,
-                         train_placeholder_labels: train_labels}
+            feed_dict = {self.data_ph: train_data,
+                         self.label_ph: train_labels}
 
-            self.session.run(self.enqueue_op, feed_dict=feed_dict)
+            self.session.run(self.inputs.enqueue_op, feed_dict=feed_dict)
 
-    """
-
-    def bootstrap_graphs_fromfile(self, train_data_shuffler, validation_data_shuffler):
-        """
-        Bootstrap all the necessary data from file
-
-         ** Parameters **
-           session: Tensorflow session
-           train_data_shuffler: Data shuffler for training
-           validation_data_shuffler: Data shuffler for validation
-
-
-        """
-        saver = self.architecture.load(self.model_from_file, clear_devices=False)
-
-        # Loading training graph
-        self.training_graph = tf.get_collection("training_graph")[0]
-
-        # Loding other elements
-        self.optimizer = tf.get_collection("optimizer")[0]
-        self.learning_rate = tf.get_collection("learning_rate")[0]
-        self.summaries_train = tf.get_collection("summaries_train")[0]
-        self.global_step = tf.get_collection("global_step")[0]
-
-        if validation_data_shuffler is not None:
-            self.validation_graph = tf.get_collection("validation_graph")[0]
-
-        self.bootstrap_placeholders_fromfile(train_data_shuffler, validation_data_shuffler)
-
-        return saver
-
-    def compute_validation(self, data_shuffler, step):
-        """
-        Computes the loss in the validation set
-
-        ** Parameters **
-            session: Tensorflow session
-            data_shuffler: The data shuffler to be used
-            step: Iteration number
-
-        """
-        # Opening a new session for validation
-        feed_dict = self.get_feed_dict(data_shuffler)
-        l = self.session.run(self.predictor, feed_dict=feed_dict)
-
-        summaries = [summary_pb2.Summary.Value(tag="loss", simple_value=float(l))]
-        self.validation_summary_writter.add_summary(summary_pb2.Summary(value=summaries), step)
-        logger.info("Loss VALIDATION set step={0} = {1}".format(step, l))
-
-    def train(self, train_data_shuffler, validation_data_shuffler=None):
+    def train(self, validation_data_shuffler=None):
         """
         Train the network:
 
@@ -323,56 +312,32 @@ class Trainer(object):
 
         # Creating directory
         bob.io.base.create_directories_safe(self.temp_dir)
-        self.train_data_shuffler = train_data_shuffler
 
         logger.info("Initializing !!")
-        self.session = Session.instance(new=True).session
 
         # Loading a pretrained model
-        if self.model_from_file != "":
-            logger.info("Loading pretrained model from {0}".format(self.model_from_file))
-            saver = self.bootstrap_graphs_fromfile(train_data_shuffler, validation_data_shuffler)
-
-            start_step = self.global_step.eval(session=self.session)
-
-        else:
+        if self.from_scratch:
             start_step = 0
-
-            # TODO: find an elegant way to provide this as a parameter of the trainer
-            self.global_step = tf.Variable(0, trainable=False, name="global_step")
-            tf.add_to_collection("global_step", self.global_step)
-
-            # Preparing the optimizer
-            self.optimizer_class._learning_rate = self.learning_rate
-            self.optimizer = self.optimizer_class.minimize(self.predictor, global_step=self.global_step)
-            tf.add_to_collection("optimizer", self.optimizer)
-            tf.add_to_collection("learning_rate", self.learning_rate)
-
-            self.summaries_train = self.create_general_summary()
-            tf.add_to_collection("summaries_train", self.summaries_train)
-
-            # Train summary
-            tf.global_variables_initializer().run(session=self.session)
-
-            # Original tensorflow saver object
-            saver = tf.train.Saver(var_list=tf.global_variables())
+        else:
+            start_step = self.global_step.eval(session=self.session)
 
         #if isinstance(train_data_shuffler, OnlineSampling):
         #    train_data_shuffler.set_feature_extractor(self.architecture, session=self.session)
 
         # Start a thread to enqueue data asynchronously, and hide I/O latency.
-        #if self.prefetch:
-        #    self.thread_pool = tf.train.Coordinator()
-        #    tf.train.start_queue_runners(coord=self.thread_pool, sess=self.session)
-        #    threads = self.start_thread()
+        if self.train_data_shuffler.prefetch:
+            self.thread_pool = tf.train.Coordinator()
+            tf.train.start_queue_runners(coord=self.thread_pool, sess=self.session)
+            threads = self.start_thread()
 
         # TENSOR BOARD SUMMARY
         self.train_summary_writter = tf.summary.FileWriter(os.path.join(self.temp_dir, 'train'), self.session.graph)
         if validation_data_shuffler is not None:
             self.validation_summary_writter = tf.summary.FileWriter(os.path.join(self.temp_dir, 'validation'),
                                                                     self.session.graph)
-
+        # Loop for
         for step in range(start_step, self.iterations):
+            # Run fit in the graph
             start = time.time()
             self.fit(step)
             end = time.time()
@@ -391,7 +356,7 @@ class Trainer(object):
             if step % self.snapshot == 0:
                 logger.info("Taking snapshot")
                 path = os.path.join(self.temp_dir, 'model_snapshot{0}.ckp'.format(step))
-                saver.save(self.session, path)
+                self.saver.save(self.session, path)
                 #self.architecture.save(saver, path)
 
         logger.info("Training finally finished")
@@ -402,9 +367,9 @@ class Trainer(object):
 
         # Saving the final network
         path = os.path.join(self.temp_dir, 'model.ckp')
-        saver.save(self.session, path)
+        self.saver.save(self.session, path)
 
-        if self.prefetch:
+        if self.train_data_shuffler.prefetch:
             # now they should definetely stop
             self.thread_pool.request_stop()
             #self.thread_pool.join(threads)