From 711641bac18ef4e34cbdb58534cc5c2397612645 Mon Sep 17 00:00:00 2001
From: Tiago Freitas Pereira <tiagofrepereira@gmail.com>
Date: Thu, 17 Nov 2016 12:46:40 +0100
Subject: [PATCH] Added tests for the triplet data shufflers

---
 .../datashuffler/TripletWithSelectionDisk.py  | 35 +++---------
 .../tensorflow/test/test_datashuffler.py      | 54 ++++++++++++++++---
 2 files changed, 52 insertions(+), 37 deletions(-)

diff --git a/bob/learn/tensorflow/datashuffler/TripletWithSelectionDisk.py b/bob/learn/tensorflow/datashuffler/TripletWithSelectionDisk.py
index a41deb99..2a09b67c 100644
--- a/bob/learn/tensorflow/datashuffler/TripletWithSelectionDisk.py
+++ b/bob/learn/tensorflow/datashuffler/TripletWithSelectionDisk.py
@@ -52,7 +52,6 @@ class TripletWithSelectionDisk(Triplet, Disk, OnLineSampling):
     def __init__(self, data, labels,
                  input_shape,
                  input_dtype="float64",
-                 scale=True,
                  batch_size=1,
                  seed=10,
                  data_augmentation=None,
@@ -64,7 +63,6 @@ class TripletWithSelectionDisk(Triplet, Disk, OnLineSampling):
             labels=labels,
             input_shape=input_shape,
             input_dtype=input_dtype,
-            scale=scale,
             batch_size=batch_size,
             seed=seed,
             data_augmentation=data_augmentation,
@@ -93,13 +91,9 @@ class TripletWithSelectionDisk(Triplet, Disk, OnLineSampling):
 
         for i in range(self.shape[0]):
             file_name_a, file_name_p, file_name_n = self.get_one_triplet(self.data, self.labels)
-            sample_a[i, ...] = self.load_from_file(str(file_name_a))
-            sample_p[i, ...] = self.load_from_file(str(file_name_p))
-            sample_n[i, ...] = self.load_from_file(str(file_name_n))
-
-        sample_a = self.normalize_sample(sample_a)
-        sample_p = self.normalize_sample(sample_p)
-        sample_n = self.normalize_sample(sample_n)
+            sample_a[i, ...] = self.normalize_sample(self.load_from_file(str(file_name_a)))
+            sample_p[i, ...] = self.normalize_sample(self.load_from_file(str(file_name_p)))
+            sample_n[i, ...] = self.normalize_sample(self.load_from_file(str(file_name_n)))
 
         return [sample_a, sample_p, sample_n]
 
@@ -181,7 +175,6 @@ class TripletWithSelectionDisk(Triplet, Disk, OnLineSampling):
         The best positive sample for the anchor is the farthest from the anchor
         """
 
-        #logger.info("****************** numpy.where")
         indexes = numpy.where(self.labels == label)[0]
 
         numpy.random.shuffle(indexes)
@@ -190,26 +183,19 @@ class TripletWithSelectionDisk(Triplet, Disk, OnLineSampling):
         distances = []
         shape = tuple([len(indexes)] + list(self.shape[1:]))
         sample_p = numpy.zeros(shape=shape, dtype='float32')
-        #logger.info("****************** search")
+
         for i in range(shape[0]):
-            #logger.info("****************** fetch")
             file_name = self.data[indexes[i], ...]
-            #logger.info("****************** load")
-            sample_p[i, ...] = self.load_from_file(str(file_name))
+            sample_p[i, ...] =  self.normalize_sample(self.load_from_file(str(file_name)))
 
-        sample_p = self.normalize_sample(sample_p)
-
-        #logger.info("****************** project")
         embedding_p = self.project(sample_p)
 
-        #logger.info("****************** distances")
         # Projecting the positive instances
         for p in embedding_p:
             distances.append(euclidean(embedding_a, p))
 
         # Geting the max
         index = numpy.argmax(distances)
-        #logger.info("****************** return")
         return sample_p[index, ...], distances[index]
 
     def get_negative(self, label, embedding_a, distance_anchor_positive):
@@ -220,7 +206,6 @@ class TripletWithSelectionDisk(Triplet, Disk, OnLineSampling):
         #anchor_feature = self.feature_extractor(self.reshape_for_deploy(anchor), session=self.session)
 
         # Selecting the negative samples
-        #logger.info("****************** numpy.where")
         indexes = numpy.where(self.labels != label)[0]
         numpy.random.shuffle(indexes)
         indexes = indexes[
@@ -228,20 +213,13 @@ class TripletWithSelectionDisk(Triplet, Disk, OnLineSampling):
 
         shape = tuple([len(indexes)] + list(self.shape[1:]))
         sample_n = numpy.zeros(shape=shape, dtype='float32')
-        #logger.info("****************** search")
         for i in range(shape[0]):
-            #logger.info("****************** fetch")
             file_name = self.data[indexes[i], ...]
-            #logger.info("****************** load")
-            sample_n[i, ...] = self.load_from_file(str(file_name))
-
-        sample_n = self.normalize_sample(sample_n)
+            sample_n[i, ...] =  self.normalize_sample(self.load_from_file(str(file_name)))
 
-        #logger.info("****************** project")
         embedding_n = self.project(sample_n)
 
         distances = []
-        #logger.info("****************** distances")
         for n in embedding_n:
             d = euclidean(embedding_a, n)
 
@@ -258,5 +236,4 @@ class TripletWithSelectionDisk(Triplet, Disk, OnLineSampling):
         if numpy.isinf(distances[index]):
             logger.info("SEMI-HARD negative not found, took the first one")
             index = 0
-        #logger.info("****************** return")
         return sample_n[index, ...]
diff --git a/bob/learn/tensorflow/test/test_datashuffler.py b/bob/learn/tensorflow/test/test_datashuffler.py
index 648fdbed..60abbc95 100644
--- a/bob/learn/tensorflow/test/test_datashuffler.py
+++ b/bob/learn/tensorflow/test/test_datashuffler.py
@@ -4,7 +4,8 @@
 # @date: Thu 13 Oct 2016 13:35 CEST
 
 import numpy
-from bob.learn.tensorflow.datashuffler import Memory, SiameseMemory, TripletMemory, Disk, SiameseDisk, TripletDisk
+from bob.learn.tensorflow.datashuffler import Memory, SiameseMemory, TripletMemory, Disk, SiameseDisk, TripletDisk, \
+    TripletWithFastSelectionDisk, TripletWithSelectionDisk
 import pkg_resources
 from bob.learn.tensorflow.utils import load_mnist
 import os
@@ -15,7 +16,6 @@ Some unit tests for the datashuffler
 
 
 def get_dummy_files():
-
     base_path = pkg_resources.resource_filename(__name__, 'data/dummy_database')
     files = []
     clients = []
@@ -28,7 +28,6 @@ def get_dummy_files():
 
 
 def test_memory_shuffler():
-
     train_data, train_labels, validation_data, validation_labels = load_mnist()
     train_data = numpy.reshape(train_data, (train_data.shape[0], 28, 28, 1))
 
@@ -50,7 +49,6 @@ def test_memory_shuffler():
 
 
 def test_siamesememory_shuffler():
-
     train_data, train_labels, validation_data, validation_labels = load_mnist()
     train_data = numpy.reshape(train_data, (train_data.shape[0], 28, 28, 1))
 
@@ -74,7 +72,6 @@ def test_siamesememory_shuffler():
 
 
 def test_tripletmemory_shuffler():
-
     train_data, train_labels, validation_data, validation_labels = load_mnist()
     train_data = numpy.reshape(train_data, (train_data.shape[0], 28, 28, 1))
 
@@ -98,7 +95,6 @@ def test_tripletmemory_shuffler():
 
 
 def test_disk_shuffler():
-
     train_data, train_labels = get_dummy_files()
 
     batch_shape = [2, 125, 125, 3]
@@ -119,7 +115,6 @@ def test_disk_shuffler():
 
 
 def test_siamesedisk_shuffler():
-
     train_data, train_labels = get_dummy_files()
 
     batch_shape = [2, 125, 125, 3]
@@ -142,7 +137,6 @@ def test_siamesedisk_shuffler():
 
 
 def test_tripletdisk_shuffler():
-
     train_data, train_labels = get_dummy_files()
 
     batch_shape = [1, 125, 125, 3]
@@ -164,3 +158,47 @@ def test_tripletdisk_shuffler():
     assert placeholders[2].get_shape().as_list() == batch_shape
 
 
+def test_triplet_fast_selection_disk_shuffler():
+    train_data, train_labels = get_dummy_files()
+
+    batch_shape = [1, 125, 125, 3]
+
+    data_shuffler = TripletWithFastSelectionDisk(train_data, train_labels,
+                                                 input_shape=batch_shape[1:],
+                                                 total_identities=1,
+                                                 batch_size=batch_shape[0])
+
+    batch = data_shuffler.get_batch()
+
+    assert len(batch) == 3
+    assert batch[0].shape == tuple(batch_shape)
+    assert batch[1].shape == tuple(batch_shape)
+    assert batch[2].shape == tuple(batch_shape)
+
+    placeholders = data_shuffler.get_placeholders(name="train")
+    assert placeholders[0].get_shape().as_list() == batch_shape
+    assert placeholders[1].get_shape().as_list() == batch_shape
+    assert placeholders[2].get_shape().as_list() == batch_shape
+
+
+def test_triplet_selection_disk_shuffler():
+    train_data, train_labels = get_dummy_files()
+
+    batch_shape = [1, 125, 125, 3]
+
+    data_shuffler = TripletWithSelectionDisk(train_data, train_labels,
+                                             input_shape=batch_shape[1:],
+                                             total_identities=1,
+                                             batch_size=batch_shape[0])
+
+    batch = data_shuffler.get_batch()
+
+    assert len(batch) == 3
+    assert batch[0].shape == tuple(batch_shape)
+    assert batch[1].shape == tuple(batch_shape)
+    assert batch[2].shape == tuple(batch_shape)
+
+    placeholders = data_shuffler.get_placeholders(name="train")
+    assert placeholders[0].get_shape().as_list() == batch_shape
+    assert placeholders[1].get_shape().as_list() == batch_shape
+    assert placeholders[2].get_shape().as_list() == batch_shape
-- 
GitLab