Skip to content
Snippets Groups Projects
Commit 33f090cb authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Testing the new machine concept

parent ef1c03ce
Branches
Tags
No related merge requests found
......@@ -21,7 +21,7 @@ class Analizer:
"""
def __init__(self, data_shuffler, graph, feature_placeholder, session):
def __init__(self, data_shuffler, machine, feature_placeholder, session):
"""
Use the CNN as feature extractor for a n-class classification
......@@ -33,7 +33,7 @@ class Analizer:
"""
self.data_shuffler = data_shuffler
self.graph = graph
self.machine = machine
self.feature_placeholder = feature_placeholder
self.session = session
......@@ -47,7 +47,7 @@ class Analizer:
data, labels = self.data_shuffler.get_batch(train_dataset=False)
feed_dict = {self.feature_placeholder: data}
return self.session.run([self.graph], feed_dict=feed_dict)[0], labels
return self.machine(feed_dict, self.session)
def __call__(self):
......
......@@ -26,18 +26,29 @@ class SequenceNetwork(six.with_metaclass(abc.ABCMeta, object)):
Base constructor
**Parameters**
input: Place Holder
feature_layer:
"""
self.sequence_net = OrderedDict()
self.feature_layer = feature_layer
def add(self, layer):
"""
Add a layer in the sequence network
"""
if not isinstance(layer, Layer):
raise ValueError("Input `layer` must be an instance of `bob.learn.tensorflow.layers.Layer`")
self.sequence_net[layer.name] = layer
def compute_graph(self, input_data, cut=False):
"""
Given the current network, return the Tensorflow graph
**Parameter**
input_data:
cut:
"""
input_offset = input_data
for k in self.sequence_net.keys():
......@@ -49,3 +60,10 @@ class SequenceNetwork(six.with_metaclass(abc.ABCMeta, object)):
return input_offset
return input_offset
def compute_projection_graph(self, placeholder):
return self.compute_graph(placeholder, cut=True)
def __call__(self, feed_dict, session):
#placeholder
return session.run([self.graph], feed_dict=feed_dict)[0]
......@@ -44,7 +44,7 @@ def main():
data_shuffler = PairDataShuffler(data, labels)
# Preparing the architecture
lenet = Lenet(feature_layer="fc1")
lenet = Lenet(feature_layer="fc2")
loss = ContrastiveLoss()
trainer = SiameseTrainer(architecture=lenet, loss=loss, iterations=ITERATIONS, base_lr=0.00001)
......
......@@ -63,7 +63,6 @@ class SiameseTrainer(object):
train_left_graph = self.architecture.compute_graph(train_placeholder_left_data)
train_right_graph = self.architecture.compute_graph(train_placeholder_right_data)
feature_graph = self.architecture.compute_graph(feature_placeholder, cut=True)
loss_train, within_class, between_class = self.loss(train_placeholder_labels,
train_left_graph,
......@@ -88,7 +87,7 @@ class SiameseTrainer(object):
print("Initializing !!")
# Training
with tf.Session() as session:
analizer = Analizer(data_shuffler, feature_graph, feature_placeholder, session)
analizer = Analizer(data_shuffler, self.architecture, feature_placeholder, session)
train_writer = tf.train.SummaryWriter('./LOGS/train',
session.graph)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment