Skip to content
Snippets Groups Projects
Commit a01b861f authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

Merge branch 'triplet-transfer-learning' into 'master'

Implemented transfer learning mechanism with the triplet estimator

See merge request !31
parents 9acb50fa 877374f7
No related branches found
No related tags found
1 merge request!31Implemented transfer learning mechanism with the triplet estimator
Pipeline #
......@@ -13,8 +13,7 @@ import time
from tensorflow.python.estimator import estimator
from bob.learn.tensorflow.utils import predict_using_tensors
from bob.learn.tensorflow.loss import triplet_loss
from . import check_features
from . import check_features, is_trainable_checkpoint
import logging
logger = logging.getLogger("bob.learn")
......@@ -70,23 +69,31 @@ class Triplet(estimator.Estimator):
validation_batch_size:
Size of the batch for validation. This value is used when the
validation with embeddings is used. This is a hack.
extra_checkpoint: dict()
In case you want to use other model to initialize some variables.
This argument should be in the following format
extra_checkpoint = {"checkpoint_path": <YOUR_CHECKPOINT>,
"scopes": dict({"<SOURCE_SCOPE>/": "<TARGET_SCOPE>/"}),
"is_trainable": <IF_THOSE_LOADED_VARIABLES_ARE_TRAINABLE>
}
"""
def __init__(self,
architecture=None,
optimizer=None,
config=None,
n_classes=0,
loss_op=triplet_loss,
model_dir="",
validation_batch_size=None,
extra_checkpoint=None
):
self.architecture = architecture
self.optimizer=optimizer
self.n_classes=n_classes
self.loss_op=loss_op
self.loss = None
self.extra_checkpoint = extra_checkpoint
if self.architecture is None:
raise ValueError("Please specify a function to build the architecture !!")
......@@ -97,9 +104,6 @@ class Triplet(estimator.Estimator):
if self.loss_op is None:
raise ValueError("Please specify a function to build the loss !!")
if self.n_classes <= 0:
raise ValueError("Number of classes must be greated than 0")
def _model_fn(features, labels, mode, params, config):
if mode == tf.estimator.ModeKeys.TRAIN:
......@@ -110,11 +114,20 @@ class Triplet(estimator.Estimator):
'negative' in features.keys():
raise ValueError("The input function needs to contain a dictionary with the "
"keys `anchor`, `positive` and `negative` ")
if self.extra_checkpoint is None:
is_trainable = True
else:
is_trainable = is_trainable_checkpoint(self.extra_checkpoint)
# Building one graph
prelogits_anchor = self.architecture(features['anchor'])[0]
prelogits_positive = self.architecture(features['positive'], reuse=True)[0]
prelogits_negative = self.architecture(features['negative'], reuse=True)[0]
prelogits_anchor = self.architecture(features['anchor'], is_trainable=is_trainable)[0]
prelogits_positive = self.architecture(features['positive'], reuse=True, is_trainable=is_trainable)[0]
prelogits_negative = self.architecture(features['negative'], reuse=True, is_trainable=is_trainable)[0]
if self.extra_checkpoint is not None:
tf.contrib.framework.init_from_checkpoint(self.extra_checkpoint["checkpoint_path"],
self.extra_checkpoint["scopes"])
# Compute Loss (for both TRAIN and EVAL modes)
self.loss = self.loss_op(prelogits_anchor, prelogits_positive, prelogits_negative)
......
......@@ -5,14 +5,15 @@
import tensorflow as tf
from bob.learn.tensorflow.network import dummy
from bob.learn.tensorflow.estimators import Triplet
from bob.learn.tensorflow.estimators import Triplet, Logits
from bob.learn.tensorflow.dataset.triplet_image import shuffle_data_and_labels_image_augmentation as triplet_batch
from bob.learn.tensorflow.dataset.image import shuffle_data_and_labels_image_augmentation as single_batch
from bob.learn.tensorflow.loss import triplet_loss
from bob.learn.tensorflow.loss import triplet_loss, mean_cross_entropy_loss
from bob.learn.tensorflow.utils.hooks import LoggerHookEstimator
from bob.learn.tensorflow.utils import reproducible
import pkg_resources
from .test_estimator_transfer import dummy_adapted
import numpy
import shutil
......@@ -22,6 +23,7 @@ import os
tfrecord_train = "./train_mnist.tfrecord"
tfrecord_validation = "./validation_mnist.tfrecord"
model_dir = "./temp"
model_dir_adapted = "./temp2"
learning_rate = 0.001
data_shape = (250, 250, 3) # size of atnt images
......@@ -77,6 +79,45 @@ def test_triplet_estimator():
pass
def test_triplettrainer_transfer():
def logits_input_fn():
return single_batch(filenames, labels, data_shape, data_type, batch_size, epochs=epochs, output_shape=output_shape)
# Trainer logits first than siamese
try:
extra_checkpoint = {"checkpoint_path":model_dir,
"scopes": dict({"Dummy/": "Dummy/"}),
"is_trainable": False
}
# LOGISTS
logits_trainer = Logits(model_dir=model_dir,
architecture=dummy,
optimizer=tf.train.GradientDescentOptimizer(learning_rate),
n_classes=2,
loss_op=mean_cross_entropy_loss,
embedding_validation=False,
validation_batch_size=validation_batch_size)
logits_trainer.train(logits_input_fn, steps=steps)
# NOW THE FUCKING SIAMESE
trainer = Triplet(model_dir=model_dir_adapted,
architecture=dummy_adapted,
optimizer=tf.train.GradientDescentOptimizer(learning_rate),
loss_op=triplet_loss,
validation_batch_size=validation_batch_size,
extra_checkpoint=extra_checkpoint)
run_triplet_estimator(trainer)
finally:
try:
shutil.rmtree(model_dir, ignore_errors=True)
shutil.rmtree(model_dir_adapted, ignore_errors=True)
except Exception:
pass
def run_triplet_estimator(trainer):
# Cleaning up
......@@ -105,4 +146,3 @@ def run_triplet_estimator(trainer):
# Cleaning up
tf.reset_default_graph()
assert len(tf.global_variables()) == 0
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment