Commit 4ed86791 authored by Tiago Pereira

Siamese networks trainer with prefetch #27

parent ccb144b1
@@ -53,7 +53,11 @@ class Disk(Base):
batch_size=1,
seed=10,
data_augmentation=None,
normalizer=Linear()):
normalizer=Linear(),
prefetch=False,
prefetch_capacity=10,
prefetch_threads=5
):
if isinstance(data, list):
data = numpy.array(data)
@@ -69,7 +73,10 @@ class Disk(Base):
batch_size=batch_size,
seed=seed,
data_augmentation=data_augmentation,
normalizer=normalizer
normalizer=normalizer,
prefetch=prefetch,
prefetch_capacity=prefetch_capacity,
prefetch_threads=prefetch_threads
)
# Setting the seed
numpy.random.seed(seed)
......
@@ -33,9 +33,25 @@ class Siamese(Base):
self.data_ph['right'] = tf.placeholder(tf.float32, shape=self.input_shape, name="right")
self.label_ph = tf.placeholder(tf.int64, shape=[None], name="label")
# If prefetch, set up the queue to feed data
if self.prefetch:
raise ValueError("There is no prefetch for siamese networks")
if self.prefetch:
queue = tf.FIFOQueue(capacity=self.prefetch_capacity,
dtypes=[tf.float32, tf.float32, tf.int64],
shapes=[self.input_shape[1:], self.input_shape[1:], []])
self.data_ph_from_queue = dict()
self.data_ph_from_queue['left'] = None
self.data_ph_from_queue['right'] = None
# Fetching the placeholders from the queue
self.enqueue_op = queue.enqueue_many([self.data_ph['left'], self.data_ph['right'], self.label_ph])
self.data_ph_from_queue['left'], self.data_ph_from_queue['right'], self.label_ph_from_queue = queue.dequeue_many(self.batch_size)
else:
self.data_ph_from_queue = dict()
self.data_ph_from_queue['left'] = self.data_ph['left']
self.data_ph_from_queue['right'] = self.data_ph['right']
self.label_ph_from_queue = self.label_ph
def get_genuine_or_not(self, input_data, input_labels, genuine=True):
......
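For reference, the prefetch wiring above boils down to the following minimal TensorFlow 1.x sketch for paired inputs; the shapes, capacity and batch size here are illustrative, not the library defaults.

import tensorflow as tf

input_shape = [None, 28, 28, 1]
batch_size = 16

# Placeholders that the enqueue thread feeds with numpy batches.
left_ph = tf.placeholder(tf.float32, shape=input_shape, name="left")
right_ph = tf.placeholder(tf.float32, shape=input_shape, name="right")
label_ph = tf.placeholder(tf.int64, shape=[None], name="label")

# The queue stores single samples: enqueue_many splits a fed batch into
# elements, dequeue_many reassembles fixed-size training batches.
queue = tf.FIFOQueue(capacity=10,
                     dtypes=[tf.float32, tf.float32, tf.int64],
                     shapes=[input_shape[1:], input_shape[1:], []])
enqueue_op = queue.enqueue_many([left_ph, right_ph, label_ph])
left_batch, right_batch, label_batch = queue.dequeue_many(batch_size)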
@@ -51,7 +51,11 @@ class SiameseDisk(Siamese, Disk):
batch_size=1,
seed=10,
data_augmentation=None,
normalizer=Linear()):
normalizer=Linear(),
prefetch=False,
prefetch_capacity=10,
prefetch_threads=5
):
if isinstance(data, list):
data = numpy.array(data)
@@ -67,7 +71,10 @@ class SiameseDisk(Siamese, Disk):
batch_size=batch_size,
seed=seed,
data_augmentation=data_augmentation,
normalizer=normalizer
normalizer=normalizer,
prefetch=prefetch,
prefetch_capacity=prefetch_capacity,
prefetch_threads=prefetch_threads
)
# Setting the seed
numpy.random.seed(seed)
......
@@ -50,7 +50,10 @@ class SiameseMemory(Siamese, Memory):
batch_size=32,
seed=10,
data_augmentation=None,
normalizer=Linear()
normalizer=Linear(),
prefetch=False,
prefetch_capacity=50,
prefetch_threads=10
):
super(SiameseMemory, self).__init__(
@@ -61,7 +64,10 @@ class SiameseMemory(Siamese, Memory):
batch_size=batch_size,
seed=seed,
data_augmentation=data_augmentation,
normalizer=normalizer
normalizer=normalizer,
prefetch=prefetch,
prefetch_capacity=prefetch_capacity,
prefetch_threads=prefetch_threads
)
# Setting the seed
numpy.random.seed(seed)
......
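As a usage sketch, a prefetching shuffler exposes two views of its placeholders: the queue-backed ones go into the training graph, while the raw ones are fed by the enqueue thread. Argument values below are illustrative, and the imports/MNIST setup are assumed to match the new test_siamesecnn_trainer test further down.

# Usage sketch only; train_data/train_labels come from load_mnist() as in the test below.
train_data_shuffler = SiameseMemory(train_data, train_labels,
                                    input_shape=[None, 28, 28, 1],
                                    batch_size=32,
                                    prefetch=True,
                                    prefetch_capacity=50,
                                    prefetch_threads=10)

# Queue-backed placeholders ({'left': ..., 'right': ...}): wire these into the Siamese graph.
graph_inputs = train_data_shuffler("data")

# Raw placeholders: feed these when running enqueue_op (see load_and_enqueue below).
feed_inputs = train_data_shuffler("data", from_queue=False)
feed_labels = train_data_shuffler("label", from_queue=False)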
@@ -17,6 +17,7 @@ import bob.io.base
import shutil
from scipy.spatial.distance import cosine
import bob.measure
from .test_cnn import dummy_experiment
"""
Some unit tests for the datashuffler
@@ -83,3 +84,53 @@ def test_cnn_trainer():
del embedding
tf.reset_default_graph()
def test_siamesecnn_trainer():
train_data, train_labels, validation_data, validation_labels = load_mnist()
train_data = numpy.reshape(train_data, (train_data.shape[0], 28, 28, 1))
validation_data = numpy.reshape(validation_data, (validation_data.shape[0], 28, 28, 1))
# Creating datashufflers
train_data_shuffler = SiameseMemory(train_data, train_labels,
input_shape=[None, 28, 28, 1],
batch_size=batch_size,
normalizer=ScaleFactor(),
prefetch=True)
validation_data_shuffler = Memory(validation_data, validation_labels,
input_shape=[None, 28, 28, 1],
batch_size=validation_batch_size,
normalizer=ScaleFactor())
directory = "./temp/siamesecnn"
# Preparing the architecture
architecture = Chopra(seed=seed, fc1_output=10)
# Loss for the Siamese
loss = ContrastiveLoss(contrastive_margin=4.)
input_pl = train_data_shuffler("data")
graph = dict()
graph['left'] = architecture(input_pl['left'])
graph['right'] = architecture(input_pl['right'], reuse=True)
trainer = SiameseTrainer(train_data_shuffler,
iterations=iterations,
analizer=None,
temp_dir=directory)
trainer.create_network_from_scratch(graph=graph,
loss=loss,
learning_rate=constant(0.01, name="regular_lr"),
optimizer=tf.train.GradientDescentOptimizer(0.01),)
trainer.train()
embedding = Embedding(validation_data_shuffler("data", from_queue=False),
architecture(validation_data_shuffler("data", from_queue=False), reuse=True))
eer = dummy_experiment(validation_data_shuffler, embedding)
assert eer < 0.15
shutil.rmtree(directory)
del architecture
del trainer # Just to clean tf.variables
tf.reset_default_graph()
@@ -57,7 +57,7 @@ def validate_network(embedding, validation_data, validation_labels):
def test_cnn_trainer_scratch():
tf.reset_default_graph()
train_data, train_labels, validation_data, validation_labels = load_mnist()
train_data = numpy.reshape(train_data, (train_data.shape[0], 28, 28, 1))
......
@@ -244,7 +244,6 @@ class SiameseTrainer(Trainer):
#self.validation_summary_writter.add_summary(summary_pb2.Summary(value=summaries), step)
logger.info("Loss VALIDATION set step={0} = {1}".format(step, l))
def load_and_enqueue(self):
"""
Injecting data in the place holder queue
@@ -254,12 +253,15 @@ class SiameseTrainer(Trainer):
"""
while not self.thread_pool.should_stop():
[train_data, train_labels] = self.train_data_shuffler.get_batch()
[train_data_left, train_data_right, train_labels] = self.train_data_shuffler.get_batch()
data_ph = self.train_data_shuffler("data", from_queue=False)
data_ph = dict()
data_ph['left'] = self.train_data_shuffler("data", from_queue=False)['left']
data_ph['right'] = self.train_data_shuffler("data", from_queue=False)['right']
label_ph = self.train_data_shuffler("label", from_queue=False)
feed_dict = {data_ph: train_data,
feed_dict = {data_ph['left']: train_data_left,
data_ph['right']: train_data_right,
label_ph: train_labels}
self.session.run(self.train_data_shuffler.enqueue_op, feed_dict=feed_dict)
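A hedged sketch of how load_and_enqueue is typically driven: it assumes the trainer's thread_pool is a tf.train.Coordinator (which provides should_stop()) and that trainer is a SiameseTrainer wired as in the new test above, since the thread start-up code itself is not part of this diff.

import threading
import tensorflow as tf

coord = tf.train.Coordinator()      # exposes should_stop()/request_stop()
trainer.thread_pool = coord         # assumption: load_and_enqueue polls this attribute
threads = [threading.Thread(target=trainer.load_and_enqueue)
           for _ in range(3)]       # e.g. one worker per prefetch thread
for t in threads:
    t.daemon = True                 # don't block interpreter shutdown
    t.start()

# ... run training steps; dequeue_many keeps batches ready ...

coord.request_stop()                # makes should_stop() return True
coord.join(threads, stop_grace_period_secs=10)
# In practice the queue may also need to be closed to unblock a full enqueue.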