Commit 711641ba authored by Tiago de Freitas Pereira

Added tests for the triplet data shufflers

parent 95398b3b
@@ -52,7 +52,6 @@ class TripletWithSelectionDisk(Triplet, Disk, OnLineSampling):
    def __init__(self, data, labels,
                 input_shape,
                 input_dtype="float64",
-                scale=True,
                 batch_size=1,
                 seed=10,
                 data_augmentation=None,
@@ -64,7 +63,6 @@ class TripletWithSelectionDisk(Triplet, Disk, OnLineSampling):
            labels=labels,
            input_shape=input_shape,
            input_dtype=input_dtype,
-            scale=scale,
            batch_size=batch_size,
            seed=seed,
            data_augmentation=data_augmentation,
@@ -93,13 +91,9 @@ class TripletWithSelectionDisk(Triplet, Disk, OnLineSampling):
        for i in range(self.shape[0]):
            file_name_a, file_name_p, file_name_n = self.get_one_triplet(self.data, self.labels)
-            sample_a[i, ...] = self.load_from_file(str(file_name_a))
-            sample_p[i, ...] = self.load_from_file(str(file_name_p))
-            sample_n[i, ...] = self.load_from_file(str(file_name_n))
-
-        sample_a = self.normalize_sample(sample_a)
-        sample_p = self.normalize_sample(sample_p)
-        sample_n = self.normalize_sample(sample_n)
+            sample_a[i, ...] = self.normalize_sample(self.load_from_file(str(file_name_a)))
+            sample_p[i, ...] = self.normalize_sample(self.load_from_file(str(file_name_p)))
+            sample_n[i, ...] = self.normalize_sample(self.load_from_file(str(file_name_n)))

        return [sample_a, sample_p, sample_n]
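For context, the hunk above moves normalization into the per-sample loading loop instead of normalizing the stacked batch afterwards. The sketch below illustrates that pattern only; `load_image`, `normalize`, and `build_triplet_batch` are hypothetical stand-ins for this repository's `load_from_file` and `normalize_sample`, and the [0, 1] scaling is an assumption for illustration.

import numpy

def load_image(path, sample_shape):
    # Hypothetical loader: stands in for Disk.load_from_file in this sketch.
    return numpy.random.uniform(0, 255, size=sample_shape).astype('float32')

def normalize(sample):
    # Assumed normalization: scale pixel values to [0, 1].
    return sample / 255.0

def build_triplet_batch(files_a, files_p, files_n, batch_shape):
    # Normalize each sample as it is loaded, as in the updated hunk,
    # rather than normalizing the whole stacked batch at the end.
    sample_a = numpy.zeros(shape=batch_shape, dtype='float32')
    sample_p = numpy.zeros(shape=batch_shape, dtype='float32')
    sample_n = numpy.zeros(shape=batch_shape, dtype='float32')
    for i in range(batch_shape[0]):
        sample_a[i, ...] = normalize(load_image(files_a[i], batch_shape[1:]))
        sample_p[i, ...] = normalize(load_image(files_p[i], batch_shape[1:]))
        sample_n[i, ...] = normalize(load_image(files_n[i], batch_shape[1:]))
    return [sample_a, sample_p, sample_n]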
@@ -181,7 +175,6 @@ class TripletWithSelectionDisk(Triplet, Disk, OnLineSampling):
        The best positive sample for the anchor is the farthest from the anchor
        """
-        #logger.info("****************** numpy.where")
        indexes = numpy.where(self.labels == label)[0]
        numpy.random.shuffle(indexes)
@@ -190,26 +183,19 @@ class TripletWithSelectionDisk(Triplet, Disk, OnLineSampling):
        distances = []
        shape = tuple([len(indexes)] + list(self.shape[1:]))

        sample_p = numpy.zeros(shape=shape, dtype='float32')
-        #logger.info("****************** search")
        for i in range(shape[0]):
-            #logger.info("****************** fetch")
            file_name = self.data[indexes[i], ...]
-            #logger.info("****************** load")
-            sample_p[i, ...] = self.load_from_file(str(file_name))
+            sample_p[i, ...] = self.normalize_sample(self.load_from_file(str(file_name)))

-        sample_p = self.normalize_sample(sample_p)
-        #logger.info("****************** project")
        embedding_p = self.project(sample_p)

-        #logger.info("****************** distances")
        # Projecting the positive instances
        for p in embedding_p:
            distances.append(euclidean(embedding_a, p))

        # Geting the max
        index = numpy.argmax(distances)
-        #logger.info("****************** return")
        return sample_p[index, ...], distances[index]

    def get_negative(self, label, embedding_a, distance_anchor_positive):
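The docstring in the hunk above states the selection rule: among candidates sharing the anchor's label, take the positive whose embedding is farthest from the anchor. A minimal NumPy sketch of that rule follows; `pick_hardest_positive` and `embed` are hypothetical names, with `embed` standing in for the shuffler's `project` method.

import numpy
from scipy.spatial.distance import euclidean

def pick_hardest_positive(embedding_a, candidate_samples, embed):
    # embed: callable mapping a stack of samples to their embeddings,
    # playing the role of self.project(...) in the class above.
    embeddings_p = embed(candidate_samples)
    distances = [euclidean(embedding_a, p) for p in embeddings_p]
    index = int(numpy.argmax(distances))  # farthest positive = hardest positive
    return candidate_samples[index, ...], distances[index]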
@@ -220,7 +206,6 @@ class TripletWithSelectionDisk(Triplet, Disk, OnLineSampling):
        #anchor_feature = self.feature_extractor(self.reshape_for_deploy(anchor), session=self.session)

        # Selecting the negative samples
-        #logger.info("****************** numpy.where")
        indexes = numpy.where(self.labels != label)[0]
        numpy.random.shuffle(indexes)
        indexes = indexes[
@@ -228,20 +213,13 @@ class TripletWithSelectionDisk(Triplet, Disk, OnLineSampling):
        shape = tuple([len(indexes)] + list(self.shape[1:]))

        sample_n = numpy.zeros(shape=shape, dtype='float32')
-        #logger.info("****************** search")
        for i in range(shape[0]):
-            #logger.info("****************** fetch")
            file_name = self.data[indexes[i], ...]
-            #logger.info("****************** load")
-            sample_n[i, ...] = self.load_from_file(str(file_name))
-
-        sample_n = self.normalize_sample(sample_n)
+            sample_n[i, ...] = self.normalize_sample(self.load_from_file(str(file_name)))

-        #logger.info("****************** project")
        embedding_n = self.project(sample_n)

        distances = []
-        #logger.info("****************** distances")
        for n in embedding_n:
            d = euclidean(embedding_a, n)
@@ -258,5 +236,4 @@ class TripletWithSelectionDisk(Triplet, Disk, OnLineSampling):
        if numpy.isinf(distances[index]):
            logger.info("SEMI-HARD negative not found, took the first one")
            index = 0
-        #logger.info("****************** return")
        return sample_n[index, ...]
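The fallback above ("SEMI-HARD negative not found, took the first one") implies that only negatives lying farther from the anchor than the chosen positive are eligible, and the closest of those is taken. The middle of that selection logic is truncated in this diff view, so the sketch below is an assumption about the criterion, not the repository's exact code; `pick_semi_hard_negative` is a hypothetical name.

import numpy
from scipy.spatial.distance import euclidean

def pick_semi_hard_negative(embedding_a, embeddings_n, distance_anchor_positive):
    # Keep only negatives farther from the anchor than the positive (semi-hard);
    # among those, take the closest. If none qualifies, fall back to the first
    # candidate, mirroring the logger message in the hunk above.
    distances = []
    for n in embeddings_n:
        d = euclidean(embedding_a, n)
        distances.append(d if d > distance_anchor_positive else numpy.inf)
    index = int(numpy.argmin(distances))
    if numpy.isinf(distances[index]):
        index = 0  # semi-hard negative not found, take the first one
    return index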
@@ -4,7 +4,8 @@
# @date: Thu 13 Oct 2016 13:35 CEST

import numpy
-from bob.learn.tensorflow.datashuffler import Memory, SiameseMemory, TripletMemory, Disk, SiameseDisk, TripletDisk
+from bob.learn.tensorflow.datashuffler import Memory, SiameseMemory, TripletMemory, Disk, SiameseDisk, TripletDisk, \
+    TripletWithFastSelectionDisk, TripletWithSelectionDisk
import pkg_resources
from bob.learn.tensorflow.utils import load_mnist
import os
@@ -15,7 +16,6 @@ Some unit tests for the datashuffler
def get_dummy_files():
    base_path = pkg_resources.resource_filename(__name__, 'data/dummy_database')

    files = []
    clients = []
@@ -28,7 +28,6 @@ def get_dummy_files():
def test_memory_shuffler():
    train_data, train_labels, validation_data, validation_labels = load_mnist()
    train_data = numpy.reshape(train_data, (train_data.shape[0], 28, 28, 1))
@@ -50,7 +49,6 @@ def test_memory_shuffler():
def test_siamesememory_shuffler():
    train_data, train_labels, validation_data, validation_labels = load_mnist()
    train_data = numpy.reshape(train_data, (train_data.shape[0], 28, 28, 1))
@@ -74,7 +72,6 @@ def test_siamesememory_shuffler():
def test_tripletmemory_shuffler():
    train_data, train_labels, validation_data, validation_labels = load_mnist()
    train_data = numpy.reshape(train_data, (train_data.shape[0], 28, 28, 1))
@@ -98,7 +95,6 @@ def test_tripletmemory_shuffler():
def test_disk_shuffler():
    train_data, train_labels = get_dummy_files()

    batch_shape = [2, 125, 125, 3]
@@ -119,7 +115,6 @@ def test_disk_shuffler():
def test_siamesedisk_shuffler():
    train_data, train_labels = get_dummy_files()

    batch_shape = [2, 125, 125, 3]
@@ -142,7 +137,6 @@ def test_siamesedisk_shuffler():
def test_tripletdisk_shuffler():
    train_data, train_labels = get_dummy_files()

    batch_shape = [1, 125, 125, 3]
@@ -164,3 +158,47 @@ def test_tripletdisk_shuffler():
    assert placeholders[2].get_shape().as_list() == batch_shape
+
+
+def test_triplet_fast_selection_disk_shuffler():
+    train_data, train_labels = get_dummy_files()
+
+    batch_shape = [1, 125, 125, 3]
+
+    data_shuffler = TripletWithFastSelectionDisk(train_data, train_labels,
+                                                 input_shape=batch_shape[1:],
+                                                 total_identities=1,
+                                                 batch_size=batch_shape[0])
+
+    batch = data_shuffler.get_batch()
+    assert len(batch) == 3
+    assert batch[0].shape == tuple(batch_shape)
+    assert batch[1].shape == tuple(batch_shape)
+    assert batch[2].shape == tuple(batch_shape)
+
+    placeholders = data_shuffler.get_placeholders(name="train")
+    assert placeholders[0].get_shape().as_list() == batch_shape
+    assert placeholders[1].get_shape().as_list() == batch_shape
+    assert placeholders[2].get_shape().as_list() == batch_shape
+
+
+def test_triplet_selection_disk_shuffler():
+    train_data, train_labels = get_dummy_files()
+
+    batch_shape = [1, 125, 125, 3]
+
+    data_shuffler = TripletWithSelectionDisk(train_data, train_labels,
+                                             input_shape=batch_shape[1:],
+                                             total_identities=1,
+                                             batch_size=batch_shape[0])
+
+    batch = data_shuffler.get_batch()
+    assert len(batch) == 3
+    assert batch[0].shape == tuple(batch_shape)
+    assert batch[1].shape == tuple(batch_shape)
+    assert batch[2].shape == tuple(batch_shape)
+
+    placeholders = data_shuffler.get_placeholders(name="train")
+    assert placeholders[0].get_shape().as_list() == batch_shape
+    assert placeholders[1].get_shape().as_list() == batch_shape
+    assert placeholders[2].get_shape().as_list() == batch_shape
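As a usage note, the batches and placeholders that the new tests exercise are meant to be paired up at training time. A minimal sketch under stated assumptions: a TensorFlow 1.x-style session, `train_data`/`train_labels` prepared as by `get_dummy_files()` in the test module, and a purely hypothetical `triplet_loss` node built over the three placeholders (left commented out below).

import tensorflow as tf
from bob.learn.tensorflow.datashuffler import TripletWithSelectionDisk

# train_data, train_labels assumed to come from get_dummy_files() above.
data_shuffler = TripletWithSelectionDisk(train_data, train_labels,
                                         input_shape=[125, 125, 3],
                                         total_identities=1,
                                         batch_size=1)

anchor_pl, positive_pl, negative_pl = data_shuffler.get_placeholders(name="train")
batch_a, batch_p, batch_n = data_shuffler.get_batch()

with tf.Session() as session:
    session.run(tf.global_variables_initializer())
    feed_dict = {anchor_pl: batch_a,
                 positive_pl: batch_p,
                 negative_pl: batch_n}
    # session.run(triplet_loss, feed_dict=feed_dict)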