Commit 4afda035 authored by Tiago de Freitas Pereira
New triplet selection algorithm

parent 88d13aff
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Wed 11 May 2016 09:39:36 CEST
import numpy
import tensorflow as tf
from .Disk import Disk
from .Triplet import Triplet
from .OnlineSampling import OnLineSampling
from scipy.spatial.distance import euclidean, cdist
import logging
logger = logging.getLogger("bob.learn.tensorflow")
class TripletWithFastSelectionDisk(Triplet, Disk, OnLineSampling):
    """
    This data shuffler generates triplets from data stored on disk (see
    :py:class:`bob.learn.tensorflow.datashuffler.Disk`).

    The triplet selection is inspired by the paper:

    Schroff, Florian, Dmitry Kalenichenko, and James Philbin.
    "FaceNet: A unified embedding for face recognition and clustering."
    Proceedings of the IEEE Conference on Computer Vision and Pattern
    Recognition. 2015.

    In this shuffler, the triplets are selected as follows:

    1. Select M identities.
    2. For each of the M identities, get N anchor-positive pairs, ideally the
       hardest ones: argmax_{x_p} ||f(x_a) - f(x_p)||^2.
    3. For each anchor-positive pair, find a "semi-hard" negative: among the
       candidates x_n with ||f(x_a) - f(x_p)||^2 < ||f(x_a) - f(x_n)||^2, take
       the one closest to the anchor (argmin_{x_n} ||f(x_a) - f(x_n)||^2).

    A standalone sketch of step 3 is appended at the end of this file.

    **Parameters**

    data: List of file names with the input data.
    labels: Label of each entry in ``data``.
    input_shape: Shape of one batch of data (batch size first).
    input_dtype: dtype of the input data.
    scale: If True, the data is multiplied by ``scale_value``.
    batch_size: Number of triplets per batch.
    seed: Seed for the numpy random number generator.
    data_augmentation: Optional data-augmentation object.
    total_identities: Number of identities inside of the batch.
    """
    def __init__(self, data, labels,
                 input_shape,
                 input_dtype="float64",
                 scale=True,
                 batch_size=1,
                 seed=10,
                 data_augmentation=None,
                 total_identities=10):

        super(TripletWithFastSelectionDisk, self).__init__(
            data=data,
            labels=labels,
            input_shape=input_shape,
            input_dtype=input_dtype,
            scale=scale,
            batch_size=batch_size,
            seed=seed,
            data_augmentation=data_augmentation
        )
        self.clear_variables()

        # Setting the seed
        numpy.random.seed(seed)

        self.total_identities = total_identities
        self.first_batch = True

        # For the negative search, load `batch_increase_factor` times the
        # batch size worth of candidate samples
        self.batch_increase_factor = 4
    def get_random_batch(self):
        """
        Get a batch of completely random triplets (used only for the first
        batch, before the network can produce embeddings).

        **Return**
          [sample_a, sample_p, sample_n]
        """
        sample_a = numpy.zeros(shape=self.shape, dtype='float32')
        sample_p = numpy.zeros(shape=self.shape, dtype='float32')
        sample_n = numpy.zeros(shape=self.shape, dtype='float32')

        for i in range(self.shape[0]):
            file_name_a, file_name_p, file_name_n = self.get_one_triplet(self.data, self.labels)
            sample_a[i, ...] = self.load_from_file(str(file_name_a))
            sample_p[i, ...] = self.load_from_file(str(file_name_p))
            sample_n[i, ...] = self.load_from_file(str(file_name_n))

        if self.scale:
            sample_a *= self.scale_value
            sample_p *= self.scale_value
            sample_n *= self.scale_value

        return [sample_a, sample_p, sample_n]
    def get_batch(self):
        """
        Get a batch of SELECTED triplets.

        **Return**
          samples_a, samples_p, samples_n
        """
        # No embeddings are available yet, so the first batch is fully random
        if self.first_batch:
            self.first_batch = False
            return self.get_random_batch()

        # Selecting the classes used in the batch
        indexes = numpy.random.choice(len(self.possible_labels), self.total_identities, replace=False)
        samples_per_identity = int(numpy.ceil(self.batch_size / float(self.total_identities)))
        anchor_labels = numpy.ones(samples_per_identity) * self.possible_labels[indexes[0]]
        for i in range(1, self.total_identities):
            anchor_labels = numpy.hstack(
                (anchor_labels, numpy.ones(samples_per_identity) * self.possible_labels[indexes[i]]))
        anchor_labels = anchor_labels[0:self.batch_size]
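        # Worked example of the layout above (illustrative numbers, not from
        # the original source): with batch_size=16 and total_identities=10,
        # samples_per_identity = ceil(16/10) = 2, so `anchor_labels` is 10
        # blocks of 2 repeated labels each, truncated to the first 16 entries.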
        samples_a = numpy.zeros(shape=self.shape, dtype='float32')

        # Computing the anchor embeddings
        for i in range(self.shape[0]):
            samples_a[i, ...] = self.get_anchor(anchor_labels[i])
        embedding_a = self.project(samples_a)

        # Getting the positives and then the semi-hard negatives
        samples_p, embedding_p, d_anchor_positive = self.get_positives(anchor_labels, embedding_a)
        samples_n = self.get_negative(anchor_labels, embedding_a, d_anchor_positive)
        return samples_a, samples_p, samples_n
    def get_anchor(self, label):
        """
        Select a random sample with the given label to serve as anchor.
        """
        # Getting the indexes of the data from a particular client
        indexes = numpy.where(self.labels == label)[0]
        numpy.random.shuffle(indexes)

        file_name = self.data[indexes[0], ...]
        anchor = self.load_from_file(str(file_name))

        if self.scale:
            anchor *= self.scale_value

        return anchor
    def get_positives(self, anchor_labels, embedding_a):
        """
        Get a random positive sample for each anchor.
        """
        samples_p = numpy.zeros(shape=self.shape, dtype='float32')
        for i in range(self.shape[0]):
            l = anchor_labels[i]
            indexes = numpy.where(self.labels == l)[0]
            numpy.random.shuffle(indexes)
            file_name = self.data[indexes[0], ...]
            samples_p[i, ...] = self.load_from_file(str(file_name))

        if self.scale:
            samples_p *= self.scale_value

        embedding_p = self.project(samples_p)

        # Computing the anchor-positive distances, pair by pair
        d_anchor_positive = []
        for i in range(self.shape[0]):
            d_anchor_positive.append(euclidean(embedding_a[i, :], embedding_p[i, :]))

        return samples_p, embedding_p, d_anchor_positive
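    # Note (not in the original source): the pairwise-distance loop in
    # `get_positives` is equivalent to the vectorized expression
    #   d_anchor_positive = numpy.linalg.norm(embedding_a - embedding_p, axis=1)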
    def get_negative(self, anchor_labels, embedding_a, d_anchor_positive):
        """
        Get the semi-hard negative for each anchor.
        """
        # Shuffling all the dataset
        indexes = numpy.arange(len(self.labels))
        numpy.random.shuffle(indexes)

        # Limiting the search to a multiple of the batch size; otherwise the
        # number of comparisons would explode
        negative_samples_search = self.batch_size * self.batch_increase_factor
        indexes = indexes[0:negative_samples_search]

        # Loading the candidate samples for the semi-hard search
        shape = tuple([len(indexes)] + list(self.shape[1:]))
        temp_samples_n = numpy.zeros(shape=shape, dtype='float32')
        samples_n = numpy.zeros(shape=self.shape, dtype='float32')
        for i in range(shape[0]):
            file_name = self.data[indexes[i], ...]
            temp_samples_n[i, ...] = self.load_from_file(str(file_name))

        if self.scale:
            temp_samples_n *= self.scale_value

        # Computing all the candidate embeddings
        embedding_temp_n = self.project(temp_samples_n)

        # Computing the anchor-candidate distances
        d_anchor_negative = cdist(embedding_a, embedding_temp_n, metric='euclidean')

        # Selecting the negative samples
        for i in range(self.shape[0]):
            label = anchor_labels[i]
            # Discard candidates closer to the anchor than the positive is,
            # keeping only the semi-hard ones
            possible_candidates = [d if d > d_anchor_positive[i] else numpy.inf for d in d_anchor_negative[i]]

            for j in numpy.argsort(possible_candidates):
                # Checking that the candidate does not share the anchor's label
                if self.labels[indexes[j]] != label:
                    samples_n[i, ...] = temp_samples_n[j, ...]
                    if numpy.isinf(possible_candidates[j]):
                        logger.info("SEMI-HARD negative not found, took the first one")
                    break

        return samples_n
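# A minimal, self-contained sketch of the semi-hard rule implemented above,
# using random vectors in place of real CNN embeddings (illustrative only and
# not part of the original commit; the array sizes and the seed are
# assumptions). Run it with `python -m` so that the relative imports resolve.
if __name__ == "__main__":
    rng = numpy.random.RandomState(10)
    emb_a = rng.rand(4, 128)   # 4 anchor embeddings
    emb_p = rng.rand(4, 128)   # 4 positive embeddings (same identities)
    emb_n = rng.rand(32, 128)  # 32 candidate negative embeddings

    # Anchor-positive and anchor-candidate distances
    d_ap = numpy.linalg.norm(emb_a - emb_p, axis=1)
    d_an = cdist(emb_a, emb_n, metric='euclidean')

    for i in range(4):
        # Keep only candidates farther from the anchor than the positive,
        # then pick the closest of those: the semi-hard negative
        candidates = numpy.where(d_an[i] > d_ap[i], d_an[i], numpy.inf)
        j = numpy.argmin(candidates)
        print("anchor %d -> candidate %d (semi-hard found: %s)"
              % (i, j, not numpy.isinf(candidates[j])))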