Commit 144ce19b authored by Tiago Pereira's avatar Tiago Pereira
Browse files

Triplet networks trainer with prefetch #27

parent 4ed86791
Pipeline #11264 canceled with stages
in 1 minute and 59 seconds
......@@ -33,24 +33,24 @@ class Siamese(Base):
self.data_ph['right'] = tf.placeholder(tf.float32, shape=self.input_shape, name="right")
self.label_ph = tf.placeholder(tf.int64, shape=[None], name="label")
if self.prefetch:
queue = tf.FIFOQueue(capacity=self.prefetch_capacity,
dtypes=[tf.float32, tf.float32, tf.int64],
shapes=[self.input_shape[1:], self.input_shape[1:], []])
self.data_ph_from_queue = dict()
self.data_ph_from_queue['left'] = None
self.data_ph_from_queue['right'] = None
# Fetching the place holders from the queue
self.enqueue_op = queue.enqueue_many([self.data_ph['left'], self.data_ph['right'], self.label_ph])
self.data_ph_from_queue['left'], self.data_ph_from_queue['right'], self.label_ph_from_queue = queue.dequeue_many(self.batch_size)
else:
self.data_ph_from_queue = dict()
self.data_ph_from_queue['left'] = self.data_ph['left']
self.data_ph_from_queue['right'] = self.data_ph['right']
self.label_ph_from_queue = self.label_ph
if self.prefetch:
queue = tf.FIFOQueue(capacity=self.prefetch_capacity,
dtypes=[tf.float32, tf.float32, tf.int64],
shapes=[self.input_shape[1:], self.input_shape[1:], []])
self.data_ph_from_queue = dict()
self.data_ph_from_queue['left'] = None
self.data_ph_from_queue['right'] = None
# Fetching the place holders from the queue
self.enqueue_op = queue.enqueue_many([self.data_ph['left'], self.data_ph['right'], self.label_ph])
self.data_ph_from_queue['left'], self.data_ph_from_queue['right'], self.label_ph_from_queue = queue.dequeue_many(self.batch_size)
else:
self.data_ph_from_queue = dict()
self.data_ph_from_queue['left'] = self.data_ph['left']
self.data_ph_from_queue['right'] = self.data_ph['right']
self.label_ph_from_queue = self.label_ph
def get_genuine_or_not(self, input_data, input_labels, genuine=True):
......
......@@ -33,10 +33,25 @@ class Triplet(Base):
self.data_ph['positive'] = tf.placeholder(tf.float32, shape=self.input_shape, name="positive")
self.data_ph['negative'] = tf.placeholder(tf.float32, shape=self.input_shape, name="negative")
# If prefetch, setup the queue to feed data
if self.prefetch:
raise ValueError("There is no prefetch for siamease networks")
queue = tf.FIFOQueue(capacity=self.prefetch_capacity,
dtypes=[tf.float32, tf.float32, tf.float32],
shapes=[self.input_shape[1:], self.input_shape[1:], self.input_shape[1:]])
self.data_ph_from_queue = dict()
self.data_ph_from_queue['anchor'] = None
self.data_ph_from_queue['positive'] = None
self.data_ph_from_queue['negative'] = None
# Fetching the place holders from the queue
self.enqueue_op = queue.enqueue_many([self.data_ph['anchor'], self.data_ph['positive'], self.data_ph['negative']])
self.data_ph_from_queue['anchor'], self.data_ph_from_queue['positive'], self.data_ph_from_queue['negative'] = queue.dequeue_many(self.batch_size)
else:
self.data_ph_from_queue = dict()
self.data_ph_from_queue['anchor'] = self.data_ph['anchor']
self.data_ph_from_queue['positive'] = self.data_ph['positive']
self.data_ph_from_queue['negative'] = self.data_ph['negative']
def get_one_triplet(self, input_data, input_labels):
# Getting a pair of clients
......
......@@ -56,7 +56,11 @@ class TripletDisk(Triplet, Disk):
batch_size=1,
seed=10,
data_augmentation=None,
normalizer=Linear()):
normalizer=Linear(),
prefetch=False,
prefetch_capacity=50,
prefetch_threads=10
):
if isinstance(data, list):
data = numpy.array(data)
......@@ -71,7 +75,10 @@ class TripletDisk(Triplet, Disk):
input_dtype=input_dtype,
batch_size=batch_size,
data_augmentation=data_augmentation,
normalizer=normalizer
normalizer=normalizer,
prefetch=prefetch,
prefetch_capacity=prefetch_capacity,
prefetch_threads=prefetch_threads
)
# Seting the seed
numpy.random.seed(seed)
......
......@@ -50,7 +50,11 @@ class TripletMemory(Triplet, Memory):
batch_size=1,
seed=10,
data_augmentation=None,
normalizer=Linear()):
normalizer=Linear(),
prefetch=False,
prefetch_capacity=50,
prefetch_threads=10
):
super(TripletMemory, self).__init__(
data=data,
......@@ -60,7 +64,10 @@ class TripletMemory(Triplet, Memory):
batch_size=batch_size,
seed=seed,
data_augmentation=data_augmentation,
normalizer=normalizer
normalizer=normalizer,
prefetch=prefetch,
prefetch_capacity=prefetch_capacity,
prefetch_threads=prefetch_threads
)
# Seting the seed
numpy.random.seed(seed)
......
......@@ -110,7 +110,7 @@ def test_siamesecnn_trainer():
# Loss for the Siamese
loss = ContrastiveLoss(contrastive_margin=4.)
input_pl = train_data_shuffler("data")
input_pl = train_data_shuffler("data", from_queue=True)
graph = dict()
graph['left'] = architecture(input_pl['left'])
graph['right'] = architecture(input_pl['right'], reuse=True)
......@@ -134,3 +134,55 @@ def test_siamesecnn_trainer():
del architecture
del trainer # Just to clean tf.variables
tf.reset_default_graph()
def test_tripletcnn_trainer():
    """Train a triplet CNN on MNIST through the prefetch queue and check the EER."""
    train_data, train_labels, validation_data, validation_labels = load_mnist()
    train_data = numpy.reshape(train_data, (train_data.shape[0], 28, 28, 1))
    validation_data = numpy.reshape(validation_data, (validation_data.shape[0], 28, 28, 1))

    # Training shuffler emits (anchor, positive, negative) triplets via the
    # prefetch queue; validation uses a plain in-memory shuffler.
    train_data_shuffler = TripletMemory(train_data, train_labels,
                                        input_shape=[None, 28, 28, 1],
                                        batch_size=batch_size,
                                        normalizer=ScaleFactor(),
                                        prefetch=True)
    validation_data_shuffler = Memory(validation_data, validation_labels,
                                      input_shape=[None, 28, 28, 1],
                                      batch_size=validation_batch_size,
                                      normalizer=ScaleFactor())

    directory = "./temp/tripletcnn"

    # Shared-weight CNN and the triplet loss
    architecture = Chopra(seed=seed, fc1_output=10)
    loss = TripletLoss(margin=4.)

    # Three towers over the queue placeholders; the anchor tower creates the
    # variables, the other two reuse them (dict literal evaluates in order).
    input_pl = train_data_shuffler("data", from_queue=True)
    graph = {
        'anchor': architecture(input_pl['anchor']),
        'positive': architecture(input_pl['positive'], reuse=True),
        'negative': architecture(input_pl['negative'], reuse=True),
    }

    # One-graph trainer
    trainer = TripletTrainer(train_data_shuffler,
                             iterations=iterations,
                             analizer=None,
                             temp_dir=directory)
    trainer.create_network_from_scratch(graph=graph,
                                        loss=loss,
                                        learning_rate=constant(0.01, name="regular_lr"),
                                        optimizer=tf.train.GradientDescentOptimizer(0.01),)
    trainer.train(train_data_shuffler)

    # Score with the anchor tower fed from the plain placeholder (no queue)
    embedding = Embedding(train_data_shuffler("data", from_queue=False)['anchor'], graph['anchor'])
    eer = dummy_experiment(validation_data_shuffler, embedding)
    assert eer < 0.15

    shutil.rmtree(directory)
    del architecture
    del trainer  # Just to clean tf.variables
    tf.reset_default_graph()
......@@ -238,3 +238,24 @@ class TripletTrainer(Trainer):
tf.summary.scalar('lr', self.learning_rate)
return tf.summary.merge_all()
def load_and_enqueue(self):
    """
    Feeder-thread loop: repeatedly pull a triplet batch from the training
    data shuffler and push it into the prefetch queue, until the thread
    pool signals shutdown.

    Takes no parameters; uses ``self.session``, ``self.train_data_shuffler``
    and ``self.thread_pool`` that the trainer sets up elsewhere.
    """
    # The placeholders belong to the already-built graph and do not change
    # between iterations, so fetch them once instead of three times per pass.
    # NOTE(review): assumes train_data_shuffler("data", from_queue=False) is a
    # side-effect-free getter -- confirm against the shuffler implementation.
    placeholders = self.train_data_shuffler("data", from_queue=False)
    data_ph = {
        'anchor': placeholders['anchor'],
        'positive': placeholders['positive'],
        'negative': placeholders['negative'],
    }

    while not self.thread_pool.should_stop():
        train_data_anchor, train_data_positive, train_data_negative = \
            self.train_data_shuffler.get_batch()
        feed_dict = {data_ph['anchor']: train_data_anchor,
                     data_ph['positive']: train_data_positive,
                     data_ph['negative']: train_data_negative}
        self.session.run(self.train_data_shuffler.enqueue_op, feed_dict=feed_dict)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment