Skip to content
Snippets Groups Projects
Commit 33f090cb authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Testing the new machine concept

parent ef1c03ce
Branches
Tags
No related merge requests found
...@@ -21,7 +21,7 @@ class Analizer: ...@@ -21,7 +21,7 @@ class Analizer:
""" """
def __init__(self, data_shuffler, graph, feature_placeholder, session): def __init__(self, data_shuffler, machine, feature_placeholder, session):
""" """
Use the CNN as feature extractor for a n-class classification Use the CNN as feature extractor for a n-class classification
...@@ -33,7 +33,7 @@ class Analizer: ...@@ -33,7 +33,7 @@ class Analizer:
""" """
self.data_shuffler = data_shuffler self.data_shuffler = data_shuffler
self.graph = graph self.machine = machine
self.feature_placeholder = feature_placeholder self.feature_placeholder = feature_placeholder
self.session = session self.session = session
...@@ -47,7 +47,7 @@ class Analizer: ...@@ -47,7 +47,7 @@ class Analizer:
data, labels = self.data_shuffler.get_batch(train_dataset=False) data, labels = self.data_shuffler.get_batch(train_dataset=False)
feed_dict = {self.feature_placeholder: data} feed_dict = {self.feature_placeholder: data}
return self.session.run([self.graph], feed_dict=feed_dict)[0], labels return self.machine(feed_dict, self.session)
def __call__(self): def __call__(self):
......
...@@ -26,18 +26,29 @@ class SequenceNetwork(six.with_metaclass(abc.ABCMeta, object)): ...@@ -26,18 +26,29 @@ class SequenceNetwork(six.with_metaclass(abc.ABCMeta, object)):
Base constructor Base constructor
**Parameters** **Parameters**
input: Place Holder feature_layer:
""" """
self.sequence_net = OrderedDict() self.sequence_net = OrderedDict()
self.feature_layer = feature_layer self.feature_layer = feature_layer
def add(self, layer): def add(self, layer):
"""
Add a layer in the sequence network
"""
if not isinstance(layer, Layer): if not isinstance(layer, Layer):
raise ValueError("Input `layer` must be an instance of `bob.learn.tensorflow.layers.Layer`") raise ValueError("Input `layer` must be an instance of `bob.learn.tensorflow.layers.Layer`")
self.sequence_net[layer.name] = layer self.sequence_net[layer.name] = layer
def compute_graph(self, input_data, cut=False): def compute_graph(self, input_data, cut=False):
"""
Given the current network, return the Tensorflow graph
**Parameter**
input_data:
cut:
"""
input_offset = input_data input_offset = input_data
for k in self.sequence_net.keys(): for k in self.sequence_net.keys():
...@@ -49,3 +60,10 @@ class SequenceNetwork(six.with_metaclass(abc.ABCMeta, object)): ...@@ -49,3 +60,10 @@ class SequenceNetwork(six.with_metaclass(abc.ABCMeta, object)):
return input_offset return input_offset
return input_offset return input_offset
def compute_projection_graph(self, placeholder):
return self.compute_graph(placeholder, cut=True)
def __call__(self, feed_dict, session):
#placeholder
return session.run([self.graph], feed_dict=feed_dict)[0]
...@@ -44,7 +44,7 @@ def main(): ...@@ -44,7 +44,7 @@ def main():
data_shuffler = PairDataShuffler(data, labels) data_shuffler = PairDataShuffler(data, labels)
# Preparing the architecture # Preparing the architecture
lenet = Lenet(feature_layer="fc1") lenet = Lenet(feature_layer="fc2")
loss = ContrastiveLoss() loss = ContrastiveLoss()
trainer = SiameseTrainer(architecture=lenet, loss=loss, iterations=ITERATIONS, base_lr=0.00001) trainer = SiameseTrainer(architecture=lenet, loss=loss, iterations=ITERATIONS, base_lr=0.00001)
......
...@@ -63,7 +63,6 @@ class SiameseTrainer(object): ...@@ -63,7 +63,6 @@ class SiameseTrainer(object):
train_left_graph = self.architecture.compute_graph(train_placeholder_left_data) train_left_graph = self.architecture.compute_graph(train_placeholder_left_data)
train_right_graph = self.architecture.compute_graph(train_placeholder_right_data) train_right_graph = self.architecture.compute_graph(train_placeholder_right_data)
feature_graph = self.architecture.compute_graph(feature_placeholder, cut=True)
loss_train, within_class, between_class = self.loss(train_placeholder_labels, loss_train, within_class, between_class = self.loss(train_placeholder_labels,
train_left_graph, train_left_graph,
...@@ -88,7 +87,7 @@ class SiameseTrainer(object): ...@@ -88,7 +87,7 @@ class SiameseTrainer(object):
print("Initializing !!") print("Initializing !!")
# Training # Training
with tf.Session() as session: with tf.Session() as session:
analizer = Analizer(data_shuffler, feature_graph, feature_placeholder, session) analizer = Analizer(data_shuffler, self.architecture, feature_placeholder, session)
train_writer = tf.train.SummaryWriter('./LOGS/train', train_writer = tf.train.SummaryWriter('./LOGS/train',
session.graph) session.graph)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment