Commit 60f2c7c1 authored by Tiago de Freitas Pereira

Siamese net with tf-slim

parent e94ba466
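This commit drops the old `architecture`-based constructor of `SiameseTrainer` in favour of explicit `inputs` and `graph` dictionaries, mirroring the single-graph `Trainer`. Below is a minimal sketch of the new wiring, assembled from the test changes in this diff; the import paths, the `seed`/`iterations`/`temp_dir` values, and the `train_data_shuffler` fixture are assumptions taken from the surrounding test setup rather than shown in the hunks:

import tensorflow as tf

# Assumed module paths, following the package layout implied by the diff below
from bob.learn.tensorflow.network import Chopra
from bob.learn.tensorflow.loss import ContrastiveLoss
from bob.learn.tensorflow.trainers import SiameseTrainer
from bob.learn.tensorflow.trainers.learning_rate import constant

# Placeholders for both branches of the Siamese pair plus the pair label
inputs = {}
inputs['left'] = tf.placeholder(tf.float32, shape=[None, 28, 28, 1], name="input_left")
inputs['right'] = tf.placeholder(tf.float32, shape=[None, 28, 28, 1], name="input_right")
inputs['label'] = tf.placeholder(tf.int64, shape=[None], name="label")

# The same architecture object builds both branches of the graph
architecture = Chopra(seed=10, fc1_output=10)
graph = {}
graph['left'] = architecture(inputs['left'])
graph['right'] = architecture(inputs['right'])

# Contrastive loss over the pair, fed to the refactored trainer
# (iterations, learning rate and temp_dir values are illustrative)
loss = ContrastiveLoss(contrastive_margin=4.)
trainer = SiameseTrainer(inputs=inputs,
                         graph=graph,
                         loss=loss,
                         iterations=500,
                         prefetch=False,
                         learning_rate=constant(0.01, name="siamese_lr"),
                         temp_dir="./siamese_logs")
# trainer.train(train_data_shuffler)  # the data shuffler comes from the test fixtures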
@@ -67,7 +67,6 @@ class Chopra(object):
                 device="/cpu:0",
                 batch_norm=False):
        self.conv1_kernel_size = conv1_kernel_size
        self.conv1_output = conv1_output
        self.pooling1_size = pooling1_size
@@ -84,7 +83,6 @@ class Chopra(object):
    def __call__(self, inputs):
        slim = tf.contrib.slim
        initializer = tf.contrib.layers.xavier_initializer(uniform=False, dtype=tf.float32, seed=self.seed)
        with tf.device(self.device):
@@ -106,6 +104,6 @@ class Chopra(object):
            graph = slim.fully_connected(graph, self.fc1_output,
                                         weights_initializer=initializer,
                                         activation_fn=None,
                                         scope='fc1')
        return graph
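For context on the commit message, only fragments of the tf-slim port of `Chopra` are visible in the hunks above. A rough sketch of how a slim-based `__call__` of this shape can be assembled follows; the conv/pool layers and their parameters are illustrative assumptions, and only the `fc1` layer is taken from the hunk:

import tensorflow as tf
slim = tf.contrib.slim

def chopra_like(inputs, conv1_output=16, conv1_kernel_size=7,
                pooling1_size=2, fc1_output=10, seed=10, device="/cpu:0"):
    # Hypothetical stand-in for Chopra.__call__; the conv/pool structure is assumed, not from the diff
    initializer = tf.contrib.layers.xavier_initializer(uniform=False, dtype=tf.float32, seed=seed)
    with tf.device(device):
        graph = slim.conv2d(inputs, conv1_output, conv1_kernel_size,
                            weights_initializer=initializer, scope='conv1')
        graph = slim.max_pool2d(graph, pooling1_size, scope='pool1')
        graph = slim.flatten(graph, scope='flatten1')
        # fc1 as shown in the hunk above: Xavier-initialised weights, no activation
        graph = slim.fully_connected(graph, fc1_output,
                                     weights_initializer=initializer,
                                     activation_fn=None,
                                     scope='fc1')
    return graph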
@@ -104,7 +104,6 @@ def test_cnn_trainer():
    graph = architecture(inputs['data'])
    embedding = Embedding(inputs['data'], graph)
    # One graph trainer
    trainer = Trainer(inputs=inputs,
                      graph=graph,
@@ -113,15 +112,15 @@ def test_cnn_trainer():
                      analizer=None,
                      prefetch=False,
                      learning_rate=constant(0.01, name="regular_lr"),
                      optimizer=tf.train.GradientDescentOptimizer(0.01),
                      temp_dir=directory
                      )
    trainer.train(train_data_shuffler)
    accuracy = validate_network(embedding, validation_data, validation_labels)
    #import ipdb; ipdb.set_trace()
    # At least 80% of accuracy
    assert accuracy > 80.
    #shutil.rmtree(directory)
    shutil.rmtree(directory)
    del trainer
    del graph
@@ -143,12 +142,21 @@ def test_siamesecnn_trainer():
    # Preparing the architecture
    architecture = Chopra(seed=seed, fc1_output=10)

+   inputs = {}
+   inputs['left'] = tf.placeholder(tf.float32, shape=[None, 28, 28, 1], name="input_left")
+   inputs['right'] = tf.placeholder(tf.float32, shape=[None, 28, 28, 1], name="input_right")
+   inputs['label'] = tf.placeholder(tf.int64, shape=[None], name="label")
+
+   graph = {}
+   graph['left'] = architecture(inputs['left'])
+   graph['right'] = architecture(inputs['right'])

    # Loss for the Siamese
    loss = ContrastiveLoss(contrastive_margin=4.)

    # One graph trainer
-   trainer = SiameseTrainer(architecture=architecture,
+   trainer = SiameseTrainer(inputs=inputs,
+                            graph=graph,
                             loss=loss,
                             iterations=iterations,
                             prefetch=False,
@@ -158,6 +166,7 @@ def test_siamesecnn_trainer():
                             temp_dir=directory
                             )
+   import ipdb; ipdb.set_trace();
    trainer.train(train_data_shuffler)
    eer = dummy_experiment(validation_data_shuffler, architecture)
......
@@ -8,6 +8,7 @@ from tensorflow.core.framework import summary_pb2
from ..analyzers import ExperimentAnalizer, SoftmaxAnalizer
from ..network import SequenceNetwork
from .Trainer import Trainer
+from .learning_rate import constant
import os
import logging
logger = logging.getLogger("bob.learn")
@@ -60,7 +61,8 @@ class SiameseTrainer(Trainer):
    """

    def __init__(self,
-                architecture,
+                inputs,
+                graph,
                 optimizer=tf.train.AdamOptimizer(),
                 use_gpu=False,
                 loss=None,
@@ -84,30 +86,56 @@ class SiameseTrainer(Trainer):
                 verbosity_level=2
                 ):

-        super(SiameseTrainer, self).__init__(
-            architecture=architecture,
-            optimizer=optimizer,
-            use_gpu=use_gpu,
-            loss=loss,
-            temp_dir=temp_dir,
-
-            # Learning rate
-            learning_rate=learning_rate,
-
-            ###### training options ##########
-            convergence_threshold=convergence_threshold,
-            iterations=iterations,
-            snapshot=snapshot,
-            validation_snapshot=validation_snapshot,
-            prefetch=prefetch,
-
-            ## Analizer
-            analizer=analizer,
-
-            model_from_file=model_from_file,
-
-            verbosity_level=verbosity_level
-        )

+        import ipdb;
+        ipdb.set_trace();
+
+        self.inputs = inputs
+        self.graph = graph
+        self.loss = loss
+
+        if not isinstance(self.graph, dict) or not(('left' and 'right') in self.graph.keys()):
+            raise ValueError("Expected a dict with the elements `right` and `left` as input for the keywork `graph`")
+
+        self.predictor = self.loss(self.graph, inputs['label'])
+        self.optimizer_class = optimizer
+        self.use_gpu = use_gpu
+        self.temp_dir = temp_dir
+
+        if learning_rate is None and model_from_file == "":
+            self.learning_rate = constant()
+        else:
+            self.learning_rate = learning_rate
+
+        self.iterations = iterations
+        self.snapshot = snapshot
+        self.validation_snapshot = validation_snapshot
+        self.convergence_threshold = convergence_threshold
+        self.prefetch = prefetch
+
+        # Training variables used in the fit
+        self.optimizer = None
+        self.training_graph = None
+        self.train_data_shuffler = None
+        self.summaries_train = None
+        self.train_summary_writter = None
+        self.thread_pool = None
+
+        # Validation data
+        self.validation_graph = None
+        self.validation_summary_writter = None
+
+        # Analizer
+        self.analizer = analizer
+
+        self.thread_pool = None
+        self.enqueue_op = None
+        self.global_step = None
+        self.model_from_file = model_from_file
+        self.session = None
+
+        bob.core.log.set_verbosity_level(logger, verbosity_level)

        self.between_class_graph_train = None
        self.within_class_graph_train = None
@@ -115,6 +143,7 @@ class SiameseTrainer(Trainer):
        self.between_class_graph_validation = None
        self.within_class_graph_validation = None

    def bootstrap_placeholders(self, train_data_shuffler, validation_data_shuffler):
        """
        Persist the placeholders
......
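One note on the constructor added above: in Python the expression `('left' and 'right')` evaluates to `'right'`, so the guard in the new `__init__` effectively only checks that `'right'` is a key of `graph`. A hypothetical check that verifies both keys would look like this (not part of the commit):

# Hypothetical stricter variant of the guard shown in the diff above
if not isinstance(graph, dict) or not all(k in graph for k in ('left', 'right')):
    raise ValueError("Expected a dict with the elements `right` and `left` as input for the keyword `graph`")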
@@ -94,12 +94,13 @@ class Trainer(object):
                 verbosity_level=2):

        #if not isinstance(graph, SequenceNetwork):
        #    raise ValueError("`architecture` should be instance of `SequenceNetwork`")
        self.inputs = inputs
        self.graph = graph
        self.loss = loss

        if not isinstance(self.graph, tf.Tensor):
            raise ValueError("Expected a tf.Tensor as input for the keywork `graph`")

        self.predictor = self.loss(self.graph, inputs['label'])
        self.optimizer_class = optimizer
@@ -325,10 +326,6 @@ class Trainer(object):
        self.train_data_shuffler = train_data_shuffler

        logger.info("Initializing !!")

        if not isinstance(self.graph, tf.Tensor):
            raise NotImplemented("Not tensor still not implemented")

        self.session = Session.instance(new=True).session

        # Loading a pretrained model