Skip to content
Snippets Groups Projects
Commit 7fb3c1de authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Redefined the projector

parent 4402ab49
Branches
Tags
No related merge requests found
......@@ -46,12 +46,12 @@ class Analizer:
# Extracting features for enrollment
enroll_data, enroll_labels = self.data_shuffler.get_batch(train_dataset=False)
enroll_features = self.machine(enroll_data, self.session)
enroll_features = self.machine(enroll_data, session=self.session)
del enroll_data
# Extracting features for probing
probe_data, probe_labels = self.data_shuffler.get_batch(train_dataset=False)
probe_features = self.machine(probe_data, self.session)
probe_features = self.machine(probe_data, session=self.session)
del probe_data
# Creating models
......
......@@ -24,7 +24,7 @@ class Lenet(SequenceNetwork):
fc1_output=400,
n_classes=10,
feature_layer="fc2",
default_feature_layer="fc2",
seed=10, use_gpu = False):
"""
......@@ -42,7 +42,7 @@ class Lenet(SequenceNetwork):
seed = 10
"""
super(Lenet, self).__init__(feature_layer=feature_layer)
super(Lenet, self).__init__(default_feature_layer=default_feature_layer)
self.add(Conv2D(name="conv1", kernel_size=conv1_kernel_size, filters=conv1_output, activation=tf.nn.tanh))
self.add(MaxPooling(name="pooling1"))
......
......@@ -21,7 +21,7 @@ class SequenceNetwork(six.with_metaclass(abc.ABCMeta, object)):
Base class to create architectures using TensorFlow
"""
def __init__(self, feature_layer=None):
def __init__(self, default_feature_layer=None):
"""
Base constructor
......@@ -30,7 +30,7 @@ class SequenceNetwork(six.with_metaclass(abc.ABCMeta, object)):
"""
self.sequence_net = OrderedDict()
self.feature_layer = feature_layer
self.default_feature_layer = default_feature_layer
self.input_divide = 1.
self.input_subtract = 0.
#self.saver = None
......@@ -44,7 +44,7 @@ class SequenceNetwork(six.with_metaclass(abc.ABCMeta, object)):
raise ValueError("Input `layer` must be an instance of `bob.learn.tensorflow.layers.Layer`")
self.sequence_net[layer.name] = layer
def compute_graph(self, input_data, cut=False):
def compute_graph(self, input_data, feature_layer=None):
"""
Given the current network, return the Tensorflow graph
......@@ -59,15 +59,15 @@ class SequenceNetwork(six.with_metaclass(abc.ABCMeta, object)):
current_layer.create_variables(input_offset)
input_offset = current_layer.get_graph()
if cut and k == self.feature_layer:
if feature_layer is not None and k == feature_layer:
return input_offset
return input_offset
def compute_projection_graph(self, placeholder):
return self.compute_graph(placeholder, cut=True)
return self.compute_graph(placeholder)
def __call__(self, data, session=None):
def __call__(self, data, session=None, feature_layer=None):
if session is None:
session = tf.Session()
......@@ -81,7 +81,10 @@ class SequenceNetwork(six.with_metaclass(abc.ABCMeta, object)):
feature_placeholder = tf.placeholder(tf.float32, shape=(batch_size, width, height, channels), name="feature")
feed_dict = {feature_placeholder: data}
return session.run([self.compute_projection_graph(feature_placeholder)], feed_dict=feed_dict)[0]
if feature_layer is None:
feature_layer = self.default_feature_layer
return session.run([self.compute_graph(feature_placeholder, feature_layer)], feed_dict=feed_dict)[0]
def dump_variables(self):
......@@ -97,11 +100,6 @@ class SequenceNetwork(six.with_metaclass(abc.ABCMeta, object)):
def save(self, hdf5, step=None):
"""
Save the state of the network in HDF5 format
:param session:
:param hdf5:
:param step:
:return:
"""
# Directory that stores the tensorflow variables
......
......@@ -44,7 +44,7 @@ def main():
data_shuffler = PairDataShuffler(data, labels)
# Preparing the architecture
lenet = Lenet(feature_layer="fc2")
lenet = Lenet(default_feature_layer="fc2")
loss = ContrastiveLoss()
trainer = SiameseTrainer(architecture=lenet,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment