Commit 60f2c7c1 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Siamese net with tf-slim

parent e94ba466
Pipeline #8158 failed with stages
in 4 minutes and 36 seconds
...@@ -67,7 +67,6 @@ class Chopra(object): ...@@ -67,7 +67,6 @@ class Chopra(object):
device="/cpu:0", device="/cpu:0",
batch_norm=False): batch_norm=False):
self.conv1_kernel_size = conv1_kernel_size self.conv1_kernel_size = conv1_kernel_size
self.conv1_output = conv1_output self.conv1_output = conv1_output
self.pooling1_size = pooling1_size self.pooling1_size = pooling1_size
...@@ -84,7 +83,6 @@ class Chopra(object): ...@@ -84,7 +83,6 @@ class Chopra(object):
def __call__(self, inputs): def __call__(self, inputs):
slim = tf.contrib.slim slim = tf.contrib.slim
initializer = tf.contrib.layers.xavier_initializer(uniform=False, dtype=tf.float32, seed=self.seed)
with tf.device(self.device): with tf.device(self.device):
...@@ -106,6 +104,6 @@ class Chopra(object): ...@@ -106,6 +104,6 @@ class Chopra(object):
graph = slim.fully_connected(graph, self.fc1_output, graph = slim.fully_connected(graph, self.fc1_output,
weights_initializer=initializer, weights_initializer=initializer,
activation_fn=None,
scope='fc1') scope='fc1')
return graph return graph
...@@ -104,7 +104,6 @@ def test_cnn_trainer(): ...@@ -104,7 +104,6 @@ def test_cnn_trainer():
graph = architecture(inputs['data']) graph = architecture(inputs['data'])
embedding = Embedding(inputs['data'], graph) embedding = Embedding(inputs['data'], graph)
# One graph trainer # One graph trainer
trainer = Trainer(inputs=inputs, trainer = Trainer(inputs=inputs,
graph=graph, graph=graph,
...@@ -113,15 +112,15 @@ def test_cnn_trainer(): ...@@ -113,15 +112,15 @@ def test_cnn_trainer():
analizer=None, analizer=None,
prefetch=False, prefetch=False,
learning_rate=constant(0.01, name="regular_lr"), learning_rate=constant(0.01, name="regular_lr"),
optimizer=tf.train.GradientDescentOptimizer(0.01),
temp_dir=directory temp_dir=directory
) )
trainer.train(train_data_shuffler) trainer.train(train_data_shuffler)
accuracy = validate_network(embedding, validation_data, validation_labels) accuracy = validate_network(embedding, validation_data, validation_labels)
#import ipdb; ipdb.set_trace()
# At least 80% of accuracy # At least 80% of accuracy
assert accuracy > 80. assert accuracy > 80.
#shutil.rmtree(directory) shutil.rmtree(directory)
del trainer del trainer
del graph del graph
...@@ -143,12 +142,21 @@ def test_siamesecnn_trainer(): ...@@ -143,12 +142,21 @@ def test_siamesecnn_trainer():
# Preparing the architecture # Preparing the architecture
architecture = Chopra(seed=seed, fc1_output=10) architecture = Chopra(seed=seed, fc1_output=10)
inputs = {}
inputs['left'] = tf.placeholder(tf.float32, shape=[None, 28, 28, 1], name="input_left")
inputs['right'] = tf.placeholder(tf.float32, shape=[None, 28, 28, 1], name="input_right")
inputs['label'] = tf.placeholder(tf.int64, shape=[None], name="label")
graph = {}
graph['left'] = architecture(inputs['left'])
graph['right'] = architecture(inputs['right'])
# Loss for the Siamese # Loss for the Siamese
loss = ContrastiveLoss(contrastive_margin=4.) loss = ContrastiveLoss(contrastive_margin=4.)
# One graph trainer # One graph trainer
trainer = SiameseTrainer(architecture=architecture, trainer = SiameseTrainer(inputs=inputs,
graph=graph,
loss=loss, loss=loss,
iterations=iterations, iterations=iterations,
prefetch=False, prefetch=False,
...@@ -158,6 +166,7 @@ def test_siamesecnn_trainer(): ...@@ -158,6 +166,7 @@ def test_siamesecnn_trainer():
temp_dir=directory temp_dir=directory
) )
import ipdb; ipdb.set_trace();
trainer.train(train_data_shuffler) trainer.train(train_data_shuffler)
eer = dummy_experiment(validation_data_shuffler, architecture) eer = dummy_experiment(validation_data_shuffler, architecture)
......
...@@ -8,6 +8,7 @@ from tensorflow.core.framework import summary_pb2 ...@@ -8,6 +8,7 @@ from tensorflow.core.framework import summary_pb2
from ..analyzers import ExperimentAnalizer, SoftmaxAnalizer from ..analyzers import ExperimentAnalizer, SoftmaxAnalizer
from ..network import SequenceNetwork from ..network import SequenceNetwork
from .Trainer import Trainer from .Trainer import Trainer
from .learning_rate import constant
import os import os
import logging import logging
logger = logging.getLogger("bob.learn") logger = logging.getLogger("bob.learn")
...@@ -60,7 +61,8 @@ class SiameseTrainer(Trainer): ...@@ -60,7 +61,8 @@ class SiameseTrainer(Trainer):
""" """
def __init__(self, def __init__(self,
architecture, inputs,
graph,
optimizer=tf.train.AdamOptimizer(), optimizer=tf.train.AdamOptimizer(),
use_gpu=False, use_gpu=False,
loss=None, loss=None,
...@@ -84,30 +86,56 @@ class SiameseTrainer(Trainer): ...@@ -84,30 +86,56 @@ class SiameseTrainer(Trainer):
verbosity_level=2 verbosity_level=2
): ):
super(SiameseTrainer, self).__init__( import ipdb;
architecture=architecture, ipdb.set_trace();
optimizer=optimizer,
use_gpu=use_gpu,
loss=loss,
temp_dir=temp_dir,
# Learning rate self.inputs = inputs
learning_rate=learning_rate, self.graph = graph
self.loss = loss
###### training options ########## if not isinstance(self.graph, dict) or not(('left' and 'right') in self.graph.keys()):
convergence_threshold=convergence_threshold, raise ValueError("Expected a dict with the elements `right` and `left` as input for the keywork `graph`")
iterations=iterations,
snapshot=snapshot,
validation_snapshot=validation_snapshot,
prefetch=prefetch,
## Analizer self.predictor = self.loss(self.graph, inputs['label'])
analizer=analizer,
model_from_file=model_from_file, self.optimizer_class = optimizer
self.use_gpu = use_gpu
self.temp_dir = temp_dir
verbosity_level=verbosity_level if learning_rate is None and model_from_file == "":
) self.learning_rate = constant()
else:
self.learning_rate = learning_rate
self.iterations = iterations
self.snapshot = snapshot
self.validation_snapshot = validation_snapshot
self.convergence_threshold = convergence_threshold
self.prefetch = prefetch
# Training variables used in the fit
self.optimizer = None
self.training_graph = None
self.train_data_shuffler = None
self.summaries_train = None
self.train_summary_writter = None
self.thread_pool = None
# Validation data
self.validation_graph = None
self.validation_summary_writter = None
# Analizer
self.analizer = analizer
self.thread_pool = None
self.enqueue_op = None
self.global_step = None
self.model_from_file = model_from_file
self.session = None
bob.core.log.set_verbosity_level(logger, verbosity_level)
self.between_class_graph_train = None self.between_class_graph_train = None
self.within_class_graph_train = None self.within_class_graph_train = None
...@@ -115,6 +143,7 @@ class SiameseTrainer(Trainer): ...@@ -115,6 +143,7 @@ class SiameseTrainer(Trainer):
self.between_class_graph_validation = None self.between_class_graph_validation = None
self.within_class_graph_validation = None self.within_class_graph_validation = None
def bootstrap_placeholders(self, train_data_shuffler, validation_data_shuffler): def bootstrap_placeholders(self, train_data_shuffler, validation_data_shuffler):
""" """
Persist the placeholders Persist the placeholders
......
...@@ -94,12 +94,13 @@ class Trainer(object): ...@@ -94,12 +94,13 @@ class Trainer(object):
verbosity_level=2): verbosity_level=2):
#if not isinstance(graph, SequenceNetwork):
# raise ValueError("`architecture` should be instance of `SequenceNetwork`")
self.inputs = inputs self.inputs = inputs
self.graph = graph self.graph = graph
self.loss = loss self.loss = loss
if not isinstance(self.graph, tf.Tensor):
raise ValueError("Expected a tf.Tensor as input for the keywork `graph`")
self.predictor = self.loss(self.graph, inputs['label']) self.predictor = self.loss(self.graph, inputs['label'])
self.optimizer_class = optimizer self.optimizer_class = optimizer
...@@ -325,10 +326,6 @@ class Trainer(object): ...@@ -325,10 +326,6 @@ class Trainer(object):
self.train_data_shuffler = train_data_shuffler self.train_data_shuffler = train_data_shuffler
logger.info("Initializing !!") logger.info("Initializing !!")
if not isinstance(self.graph, tf.Tensor):
raise NotImplemented("Not tensor still not implemented")
self.session = Session.instance(new=True).session self.session = Session.instance(new=True).session
# Loading a pretrained model # Loading a pretrained model
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment