diff --git a/bob/learn/tensorflow/datashuffler/OnlineSampling.py b/bob/learn/tensorflow/datashuffler/OnlineSampling.py
index 0f630999918191572dbdc642ea5314e84f64572f..daf989de045050087075ea2e4602642fd74ec3df 100644
--- a/bob/learn/tensorflow/datashuffler/OnlineSampling.py
+++ b/bob/learn/tensorflow/datashuffler/OnlineSampling.py
@@ -42,9 +42,14 @@ class OnLineSampling(object):
         # Feeding the placeholder
 
         if self.feature_placeholder is None:
-            self.feature_placeholder = tf.placeholder(tf.float32, shape=data.shape, name="feature")
+            shape = tuple([None] + list(data.shape[1:]))
+            self.feature_placeholder = tf.placeholder(tf.float32, shape=shape, name="feature")
             self.graph = self.feature_extractor.compute_graph(self.feature_placeholder, self.feature_extractor.default_feature_layer,
                                                               training=False)
 
+        # The placeholder now has a dynamic batch dimension (None), so it
+        # accepts input of any batch size and no per-call reshape is needed.
+
         feed_dict = {self.feature_placeholder: data}
         return self.session.run([self.graph], feed_dict=feed_dict)[0]
\ No newline at end of file
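
The change above gives the placeholder a dynamic batch dimension (None), so a single placeholder and a single compiled graph serve any batch size. Below is a minimal sketch of the idea, assuming the TensorFlow 1.x placeholder/Session API; the reduce_mean stands in for the real feature-extractor graph:

    import numpy
    import tensorflow as tf

    # Batch dimension left as None: the same placeholder accepts any batch size.
    feature = tf.placeholder(tf.float32, shape=(None, 125, 125, 3), name="feature")
    graph = tf.reduce_mean(feature, axis=[1, 2, 3])  # stand-in for compute_graph(...)

    with tf.Session() as session:
        for batch_size in (8, 32):
            data = numpy.zeros((batch_size, 125, 125, 3), dtype="float32")
            output = session.run(graph, feed_dict={feature: data})
            assert output.shape == (batch_size,)
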
diff --git a/bob/learn/tensorflow/datashuffler/TripletWithSelectionDisk.py b/bob/learn/tensorflow/datashuffler/TripletWithSelectionDisk.py
index 7691c9282b3e95c0920f13a57d519311fdff992c..0b1fffff54c53d4a86b00e2d3807b7257e169040 100644
--- a/bob/learn/tensorflow/datashuffler/TripletWithSelectionDisk.py
+++ b/bob/learn/tensorflow/datashuffler/TripletWithSelectionDisk.py
@@ -11,6 +11,9 @@ from .Triplet import Triplet
 from .OnlineSampling import OnLineSampling
 from scipy.spatial.distance import euclidean
 
+import logging
+logger = logging.getLogger("bob.learn.tensorflow")
+
 
 class TripletWithSelectionDisk(Triplet, Disk, OnLineSampling):
     """
@@ -110,29 +113,42 @@ class TripletWithSelectionDisk(Triplet, Disk, OnLineSampling):
 
         # Selecting the classes used in the selection
         indexes = numpy.random.choice(len(self.possible_labels), self.total_identities, replace=False)
-        samples_per_identity = self.batch_size/self.total_identities
+        # Round up and cast to int: numpy.ones() requires an integer count
+        samples_per_identity = int(numpy.ceil(self.batch_size / float(self.total_identities)))
         anchor_labels = numpy.ones(samples_per_identity) * self.possible_labels[indexes[0]]
+
         for i in range(1, self.total_identities):
             anchor_labels = numpy.hstack((anchor_labels,numpy.ones(samples_per_identity) * self.possible_labels[indexes[i]]))
         anchor_labels = anchor_labels[0:self.batch_size]
 
         data_a = numpy.zeros(shape=self.shape, dtype='float32')
         data_p = numpy.zeros(shape=self.shape, dtype='float32')
         data_n = numpy.zeros(shape=self.shape, dtype='float32')
 
+        #logger.info("Fetching anchor")
         # Fetching the anchors
         for i in range(self.shape[0]):
             data_a[i, ...] = self.get_anchor(anchor_labels[i])
         features_a = self.project(data_a)
 
         for i in range(self.shape[0]):
+            #logger.info("*********Anchor {0}".format(i))
+
             label = anchor_labels[i]
             #anchor = self.get_anchor(label)
+            #logger.info("********* Positives")
             positive, distance_anchor_positive = self.get_positive(label, features_a[i])
+            #logger.info("********* Negatives")
             negative = self.get_negative(label, features_a[i], distance_anchor_positive)
 
+            #logger.info("********* Appending")
+
             data_p[i, ...] = positive
             data_n[i, ...] = negative
         # Scaling
         #if self.scale:
         #    data_a *= self.scale_value
@@ -164,27 +180,37 @@ class TripletWithSelectionDisk(Triplet, Disk, OnLineSampling):
         The best positive sample for the anchor is the farthest from the anchor
         """
 
+        #logger.info("****************** numpy.where")
         indexes = numpy.where(self.labels == label)[0]
+
         numpy.random.shuffle(indexes)
         indexes = indexes[
                   0:self.batch_size]  # Limiting to the batch size, otherwise the number of comparisons will explode
         distances = []
-
-        data_p = numpy.zeros(shape=self.shape, dtype='float32')
-        for i in range(self.shape[0]):
+        shape = tuple([len(indexes)] + list(self.shape[1:]))
+        data_p = numpy.zeros(shape=shape, dtype='float32')
+        #logger.info("****************** search")
+        for i in range(shape[0]):
+            #logger.info("****************** fetch")
             file_name = self.data[indexes[i], ...]
+            #logger.info("****************** load")
             data_p[i, ...] = self.load_from_file(str(file_name))
-            if self.scale:
-                data_p *= self.scale_value
+
+        if self.scale:
+            data_p *= self.scale_value
+
+        #logger.info("****************** project")
         positive_features = self.project(data_p)
 
+        #logger.info("****************** distances")
         # Projecting the positive instances
         for p in positive_features:
             distances.append(euclidean(anchor_feature, p))
 
         # Getting the max
         index = numpy.argmax(distances)
-        return self.data[indexes[index], ...], distances[index]
+        #logger.info("****************** return")
+        return data_p[index, ...], distances[index]
 
     def get_negative(self, label, anchor_feature, distance_anchor_positive):
         """
@@ -194,20 +220,29 @@ class TripletWithSelectionDisk(Triplet, Disk, OnLineSampling):
         #anchor_feature = self.feature_extractor(self.reshape_for_deploy(anchor), session=self.session)
 
         # Selecting the negative samples
+        #logger.info("****************** numpy.where")
         indexes = numpy.where(self.labels != label)[0]
         numpy.random.shuffle(indexes)
         indexes = indexes[
-                  0:self.batch_size] # Limiting to the batch size, otherwise the number of comparisons will explode
+                  0:self.batch_size*3] # Limiting to 3x the batch size, otherwise the number of comparisons will explode
 
-        data_n = numpy.zeros(shape=self.shape, dtype='float32')
-        for i in range(self.shape[0]):
+        shape = tuple([len(indexes)] + list(self.shape[1:]))
+        data_n = numpy.zeros(shape=shape, dtype='float32')
+        #logger.info("****************** search")
+        for i in range(shape[0]):
+            #logger.info("****************** fetch")
             file_name = self.data[indexes[i], ...]
+            #logger.info("****************** load")
             data_n[i, ...] = self.load_from_file(str(file_name))
-            if self.scale:
-                data_n *= self.scale_value
+
+        if self.scale:
+            data_n *= self.scale_value
+
+        #logger.info("****************** project")
         negative_features = self.project(data_n)
 
         distances = []
+        #logger.info("****************** distances")
         for n in negative_features:
             d = euclidean(anchor_feature, n)
 
@@ -222,6 +257,7 @@ class TripletWithSelectionDisk(Triplet, Disk, OnLineSampling):
 
         # If the semi-hardest distance is inf, take the first one
         if numpy.isinf(distances[index]):
+            logger.info("SEMI-HARD negative not found, took the first one")
             index = 0
-
-        return self.data[indexes[index], ...]
+        #logger.info("****************** return")
+        return data_n[index, ...]
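
Taken together, get_positive and get_negative implement farthest-positive / semi-hard-negative selection: the positive is the same-class sample farthest from the anchor, and the negative is the closest different-class sample that is still farther away than that positive, falling back to the first candidate when none qualifies. A self-contained numpy sketch of that rule (the function and argument names here are hypothetical, not part of the class):

    import numpy
    from scipy.spatial.distance import euclidean

    def select_triplet(anchor_feature, positive_features, negative_features):
        # Hardest positive: the same-class sample farthest from the anchor.
        d_ap = [euclidean(anchor_feature, p) for p in positive_features]
        p_index = int(numpy.argmax(d_ap))

        # Semi-hard negative: the closest different-class sample that is
        # still farther than the chosen positive; inf marks candidates
        # that do not qualify.
        d_an = []
        for n in negative_features:
            d = euclidean(anchor_feature, n)
            d_an.append(d if d > d_ap[p_index] else numpy.inf)
        n_index = int(numpy.argmin(d_an))

        if numpy.isinf(d_an[n_index]):
            n_index = 0  # no semi-hard negative found, take the first one
        return p_index, n_index
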
diff --git a/bob/learn/tensorflow/datashuffler/__init__.py b/bob/learn/tensorflow/datashuffler/__init__.py
index 55d3b2232033f1745dc707a5fb317ccc876b6169..2c5751fad9b8b2a7887eaeedf843e41d703da060 100644
--- a/bob/learn/tensorflow/datashuffler/__init__.py
+++ b/bob/learn/tensorflow/datashuffler/__init__.py
@@ -15,5 +15,8 @@ from .SiameseDisk import SiameseDisk
 from .TripletDisk import TripletDisk
 from .TripletWithSelectionDisk import TripletWithSelectionDisk
 
+from .DataAugmentation import DataAugmentation
+from .ImageAugmentation import ImageAugmentation
+
 # gets sphinx autodoc done right - don't remove it
 __all__ = [_ for _ in dir() if not _.startswith('_')]
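
With these exports in place, both augmentation classes become importable from the package root, e.g.:

    from bob.learn.tensorflow.datashuffler import DataAugmentation, ImageAugmentation
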
diff --git a/bob/learn/tensorflow/layers/FullyConnected.py b/bob/learn/tensorflow/layers/FullyConnected.py
index 1277bf88b3c1861d312beef51e557d2c2d5280bd..992162e0b591fe882105e1b8fec6d34a00ce475e 100644
--- a/bob/learn/tensorflow/layers/FullyConnected.py
+++ b/bob/learn/tensorflow/layers/FullyConnected.py
@@ -61,7 +61,9 @@ class FullyConnected(Layer):
 
             if len(self.input_layer.get_shape()) == 4:
                 shape = self.input_layer.get_shape().as_list()
-                fc = tf.reshape(self.input_layer, [shape[0], shape[1] * shape[2] * shape[3]])
+                # Flatten with -1 so the batch dimension is inferred at run
+                # time (shape[0] is None for a dynamic batch size).
+                fc = tf.reshape(self.input_layer, [-1, shape[1] * shape[2] * shape[3]])
             else:
                 fc = self.input_layer
 
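
This reshape fix matters once the batch dimension is dynamic: get_shape().as_list() then returns None for shape[0], and passing None to tf.reshape fails, whereas -1 asks TensorFlow to infer that dimension at run time. A quick illustration, again assuming the TF 1.x API:

    import tensorflow as tf

    x = tf.placeholder(tf.float32, shape=(None, 5, 5, 8))
    shape = x.get_shape().as_list()  # [None, 5, 5, 8]
    # shape[0] is None, so the batch dimension must be inferred with -1:
    flat = tf.reshape(x, [-1, shape[1] * shape[2] * shape[3]])
    print(flat.get_shape())  # (?, 200)
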
diff --git a/bob/learn/tensorflow/script/train_siamese_casia_webface.py b/bob/learn/tensorflow/script/train_siamese_casia_webface.py
index 9506fb4b682f95f234902400dd2837988072d381..80ff95390a040319932f12f4ad2f007e6e88381a 100644
--- a/bob/learn/tensorflow/script/train_siamese_casia_webface.py
+++ b/bob/learn/tensorflow/script/train_siamese_casia_webface.py
@@ -22,10 +22,10 @@ from docopt import docopt
 import tensorflow as tf
 from .. import util
 SEED = 10
-from bob.learn.tensorflow.data import MemoryDataShuffler, TextDataShuffler
+from bob.learn.tensorflow.datashuffler import TripletDisk, TripletWithSelectionDisk
 from bob.learn.tensorflow.network import Lenet, MLP, LenetDropout, VGG, Chopra, Dummy
-from bob.learn.tensorflow.trainers import SiameseTrainer
-from bob.learn.tensorflow.loss import ContrastiveLoss
+from bob.learn.tensorflow.trainers import SiameseTrainer, TripletTrainer
+from bob.learn.tensorflow.loss import ContrastiveLoss, TripletLoss
 import numpy
 
 
@@ -56,9 +56,9 @@ def main():
         extension="")
                         for o in train_objects]
 
-    train_data_shuffler = TextDataShuffler(train_file_names, train_labels,
-                                           input_shape=[125, 125, 3],
-                                           batch_size=BATCH_SIZE)
+    train_data_shuffler = TripletWithSelectionDisk(train_file_names, train_labels,
+                                                   input_shape=[125, 125, 3],
+                                                   batch_size=BATCH_SIZE)
 
     # Preparing train set
     directory = "/idiap/temp/tpereira/DEEP_FACE/CASIA/preprocessed"
@@ -70,20 +70,27 @@ def main():
         extension=".hdf5")
                              for o in validation_objects]
 
-    validation_data_shuffler = TextDataShuffler(validation_file_names, validation_labels,
-                                                input_shape=[125, 125, 3],
-                                                batch_size=VALIDATION_BATCH_SIZE)
+    validation_data_shuffler = TripletDisk(validation_file_names, validation_labels,
+                                           input_shape=[125, 125, 3],
+                                           batch_size=VALIDATION_BATCH_SIZE)
     # Preparing the architecture
     # LENET PAPER CHOPRA
     architecture = Chopra(seed=SEED)
 
-    loss = ContrastiveLoss(contrastive_margin=50.)
-    optimizer = tf.train.GradientDescentOptimizer(0.00001)
-    trainer = SiameseTrainer(architecture=architecture,
-                             loss=loss,
+    #loss = ContrastiveLoss(contrastive_margin=50.)
+    #optimizer = tf.train.GradientDescentOptimizer(0.00001)
+    #trainer = SiameseTrainer(architecture=architecture,
+    #                         loss=loss,
+    #                         iterations=ITERATIONS,
+    #                         snapshot=VALIDATION_TEST,
+    #                         optimizer=optimizer)
+
+    loss = TripletLoss(margin=4.)
+    trainer = TripletTrainer(architecture=architecture, loss=loss,
                              iterations=ITERATIONS,
-                             snapshot=VALIDATION_TEST,
-                             optimizer=optimizer)
+                             prefetch=False,
+                             temp_dir="./LOGS_CASIA/triplet-cnn")
 
     trainer.train(train_data_shuffler, validation_data_shuffler)
     #trainer.train(train_data_shuffler)
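
For reference, the quantity a triplet trainer with margin=4 minimizes is the standard triplet hinge loss; a numpy sketch of the formula (the project's TripletLoss class may differ in details such as squared versus plain distances):

    import numpy

    def triplet_hinge(anchor, positive, negative, margin=4.0):
        # Pull anchor-positive pairs together and push anchor-negative
        # pairs at least `margin` apart, using squared Euclidean distances.
        d_ap = numpy.sum((anchor - positive) ** 2, axis=1)
        d_an = numpy.sum((anchor - negative) ** 2, axis=1)
        return numpy.mean(numpy.maximum(d_ap - d_an + margin, 0.0))
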