From 33f090cb75cd40a7c1479630423e1ee87f5c2865 Mon Sep 17 00:00:00 2001 From: Tiago Freitas Pereira <tiagofrepereira@gmail.com> Date: Sun, 28 Aug 2016 21:37:49 +0200 Subject: [PATCH] Testing the new machine concept --- bob/learn/tensorflow/analyzers/Analizer.py | 6 +++--- .../tensorflow/network/SequenceNetwork.py | 20 ++++++++++++++++++- .../tensorflow/script/train_mnist_siamese.py | 2 +- .../tensorflow/trainers/SiameseTrainer.py | 3 +-- 4 files changed, 24 insertions(+), 7 deletions(-) diff --git a/bob/learn/tensorflow/analyzers/Analizer.py b/bob/learn/tensorflow/analyzers/Analizer.py index 27ea52f0..70f54f29 100644 --- a/bob/learn/tensorflow/analyzers/Analizer.py +++ b/bob/learn/tensorflow/analyzers/Analizer.py @@ -21,7 +21,7 @@ class Analizer: """ - def __init__(self, data_shuffler, graph, feature_placeholder, session): + def __init__(self, data_shuffler, machine, feature_placeholder, session): """ Use the CNN as feature extractor for a n-class classification @@ -33,7 +33,7 @@ class Analizer: """ self.data_shuffler = data_shuffler - self.graph = graph + self.machine = machine self.feature_placeholder = feature_placeholder self.session = session @@ -47,7 +47,7 @@ class Analizer: data, labels = self.data_shuffler.get_batch(train_dataset=False) feed_dict = {self.feature_placeholder: data} - return self.session.run([self.graph], feed_dict=feed_dict)[0], labels + return self.machine(feed_dict, self.session) def __call__(self): diff --git a/bob/learn/tensorflow/network/SequenceNetwork.py b/bob/learn/tensorflow/network/SequenceNetwork.py index 2207360f..93452b01 100644 --- a/bob/learn/tensorflow/network/SequenceNetwork.py +++ b/bob/learn/tensorflow/network/SequenceNetwork.py @@ -26,18 +26,29 @@ class SequenceNetwork(six.with_metaclass(abc.ABCMeta, object)): Base constructor **Parameters** - input: Place Holder + feature_layer: """ self.sequence_net = OrderedDict() self.feature_layer = feature_layer def add(self, layer): + """ + Add a layer in the sequence network + + """ if not isinstance(layer, Layer): raise ValueError("Input `layer` must be an instance of `bob.learn.tensorflow.layers.Layer`") self.sequence_net[layer.name] = layer def compute_graph(self, input_data, cut=False): + """ + Given the current network, return the Tensorflow graph + + **Parameter** + input_data: + cut: + """ input_offset = input_data for k in self.sequence_net.keys(): @@ -49,3 +60,10 @@ class SequenceNetwork(six.with_metaclass(abc.ABCMeta, object)): return input_offset return input_offset + + def compute_projection_graph(self, placeholder): + return self.compute_graph(placeholder, cut=True) + + def __call__(self, feed_dict, session): + #placeholder + return session.run([self.graph], feed_dict=feed_dict)[0] diff --git a/bob/learn/tensorflow/script/train_mnist_siamese.py b/bob/learn/tensorflow/script/train_mnist_siamese.py index 42804b9f..98a83bc2 100644 --- a/bob/learn/tensorflow/script/train_mnist_siamese.py +++ b/bob/learn/tensorflow/script/train_mnist_siamese.py @@ -44,7 +44,7 @@ def main(): data_shuffler = PairDataShuffler(data, labels) # Preparing the architecture - lenet = Lenet(feature_layer="fc1") + lenet = Lenet(feature_layer="fc2") loss = ContrastiveLoss() trainer = SiameseTrainer(architecture=lenet, loss=loss, iterations=ITERATIONS, base_lr=0.00001) diff --git a/bob/learn/tensorflow/trainers/SiameseTrainer.py b/bob/learn/tensorflow/trainers/SiameseTrainer.py index 7b7adf5d..a6629d06 100644 --- a/bob/learn/tensorflow/trainers/SiameseTrainer.py +++ b/bob/learn/tensorflow/trainers/SiameseTrainer.py @@ -63,7 +63,6 @@ class SiameseTrainer(object): train_left_graph = self.architecture.compute_graph(train_placeholder_left_data) train_right_graph = self.architecture.compute_graph(train_placeholder_right_data) - feature_graph = self.architecture.compute_graph(feature_placeholder, cut=True) loss_train, within_class, between_class = self.loss(train_placeholder_labels, train_left_graph, @@ -88,7 +87,7 @@ class SiameseTrainer(object): print("Initializing !!") # Training with tf.Session() as session: - analizer = Analizer(data_shuffler, feature_graph, feature_placeholder, session) + analizer = Analizer(data_shuffler, self.architecture, feature_placeholder, session) train_writer = tf.train.SummaryWriter('./LOGS/train', session.graph) -- GitLab