Skip to content
Snippets Groups Projects
Commit 6d4910cc authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Fixed Facenet

parent e34b58c0
Branches
Tags v0.0.1b12
No related merge requests found
...@@ -184,13 +184,13 @@ class FaceNet(SequenceNetwork): ...@@ -184,13 +184,13 @@ class FaceNet(SequenceNetwork):
self.add(MaxPooling(name="pooling6", shape=pool6_shape)) self.add(MaxPooling(name="pooling6", shape=pool6_shape))
self.add(FullyConnected(name="fc1", output_dim=fc1_output, self.add(FullyConnected(name="fc1", output_dim=fc1_output,
activation=tf.nn.tanh, activation=tf.nn.relu,
weights_initialization=Xavier(seed=seed, use_gpu=self.use_gpu), weights_initialization=Xavier(seed=seed, use_gpu=self.use_gpu),
bias_initialization=Constant(use_gpu=self.use_gpu) bias_initialization=Constant(use_gpu=self.use_gpu)
)) ))
self.add(FullyConnected(name="fc2", output_dim=fc2_output, self.add(FullyConnected(name="fc2", output_dim=fc2_output,
activation=tf.nn.tanh, activation=tf.nn.relu,
weights_initialization=Xavier(seed=seed, use_gpu=self.use_gpu), weights_initialization=Xavier(seed=seed, use_gpu=self.use_gpu),
bias_initialization=Constant(use_gpu=self.use_gpu) bias_initialization=Constant(use_gpu=self.use_gpu)
)) ))
......
...@@ -220,7 +220,12 @@ class SequenceNetwork(six.with_metaclass(abc.ABCMeta, object)): ...@@ -220,7 +220,12 @@ class SequenceNetwork(six.with_metaclass(abc.ABCMeta, object)):
hdf5.set('input_divide', self.input_divide) hdf5.set('input_divide', self.input_divide)
hdf5.set('input_subtract', self.input_subtract) hdf5.set('input_subtract', self.input_subtract)
def load(self, hdf5, shape=None, session=None): def turn_gpu_onoff(self, state=True):
for k in self.sequence_net:
self.sequence_net[k].weights_initialization.use_gpu = state
self.sequence_net[k].bias_initialization.use_gpu = state
def load(self, hdf5, shape=None, session=None, batch=1):
""" """
Load the network Load the network
...@@ -240,12 +245,15 @@ class SequenceNetwork(six.with_metaclass(abc.ABCMeta, object)): ...@@ -240,12 +245,15 @@ class SequenceNetwork(six.with_metaclass(abc.ABCMeta, object)):
self.input_divide = hdf5.read('input_divide') self.input_divide = hdf5.read('input_divide')
self.input_subtract = hdf5.read('input_subtract') self.input_subtract = hdf5.read('input_subtract')
# Saving the architecture # Loading architecture
self.sequence_net = pickle.loads(hdf5.read('architecture')) self.sequence_net = pickle.loads(hdf5.read('architecture'))
self.deployment_shape = hdf5.read('deployment_shape') self.deployment_shape = hdf5.read('deployment_shape')
self.turn_gpu_onoff(False)
if shape is None: if shape is None:
shape = self.deployment_shape shape = self.deployment_shape
shape[0] = batch
# Loading variables # Loading variables
place_holder = tf.placeholder(tf.float32, shape=shape, name="load") place_holder = tf.placeholder(tf.float32, shape=shape, name="load")
......
...@@ -60,11 +60,16 @@ def main(): ...@@ -60,11 +60,16 @@ def main():
# input_shape=[125, 125, 3], # input_shape=[125, 125, 3],
# batch_size=BATCH_SIZE) # batch_size=BATCH_SIZE)
train_data_shuffler = TripletWithFastSelectionDisk(train_file_names, train_labels, #train_data_shuffler = TripletWithFastSelectionDisk(train_file_names, train_labels,
# input_shape=[112, 112, 3],
# batch_size=BATCH_SIZE)
train_data_shuffler = TripletDisk(train_file_names, train_labels,
input_shape=[112, 112, 3], input_shape=[112, 112, 3],
batch_size=BATCH_SIZE) batch_size=BATCH_SIZE)
# Preparing train set # Preparing train set
directory = "/idiap/temp/tpereira/DEEP_FACE/CASIA_WEBFACE/mobio/preprocessed" directory = "/idiap/temp/tpereira/DEEP_FACE/CASIA_WEBFACE/mobio/preprocessed"
validation_objects = sorted(db_mobio.objects(protocol="male", groups="dev"), key=lambda x: x.id) validation_objects = sorted(db_mobio.objects(protocol="male", groups="dev"), key=lambda x: x.id)
...@@ -91,13 +96,13 @@ def main(): ...@@ -91,13 +96,13 @@ def main():
# snapshot=VALIDATION_TEST, # snapshot=VALIDATION_TEST,
# optimizer=optimizer) # optimizer=optimizer)
loss = TripletLoss(margin=0.5) loss = TripletLoss(margin=0.2)
trainer = TripletTrainer(architecture=architecture, loss=loss, trainer = TripletTrainer(architecture=architecture, loss=loss,
iterations=ITERATIONS, iterations=ITERATIONS,
base_learning_rate=0.1, base_learning_rate=0.05,
prefetch=False, prefetch=False,
temp_dir="./LOGS_CASIA/triplet-cnn-fast-selection") temp_dir="/idiap/temp/tpereira/CNN_MODELS/triplet-cnn-RANDOM-selection-gpu")
trainer.train(train_data_shuffler, validation_data_shuffler) #trainer.train(train_data_shuffler, validation_data_shuffler)
#trainer.train(train_data_shuffler) trainer.train(train_data_shuffler)
...@@ -14,8 +14,8 @@ from tensorflow.core.framework import summary_pb2 ...@@ -14,8 +14,8 @@ from tensorflow.core.framework import summary_pb2
import time import time
from bob.learn.tensorflow.datashuffler.OnlineSampling import OnLineSampling from bob.learn.tensorflow.datashuffler.OnlineSampling import OnLineSampling
os.environ["CUDA_VISIBLE_DEVICES"] = "3,2,0,1" #os.environ["CUDA_VISIBLE_DEVICES"] = "1,3,0,2"
#os.environ["CUDA_VISIBLE_DEVICES"] = "" os.environ["CUDA_VISIBLE_DEVICES"] = ""
logger = bob.core.log.setup("bob.learn.tensorflow") logger = bob.core.log.setup("bob.learn.tensorflow")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment