Skip to content
Snippets Groups Projects
Commit 877374f7 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Implemented transfer learning mechanism with the triplet estimator

parent 9acb50fa
Branches
Tags
1 merge request!31Implemented transfer learning mechanism with the triplet estimator
Pipeline #
......@@ -13,8 +13,7 @@ import time
from tensorflow.python.estimator import estimator
from bob.learn.tensorflow.utils import predict_using_tensors
from bob.learn.tensorflow.loss import triplet_loss
from . import check_features
from . import check_features, is_trainable_checkpoint
import logging
logger = logging.getLogger("bob.learn")
......@@ -70,23 +69,31 @@ class Triplet(estimator.Estimator):
validation_batch_size:
Size of the batch for validation. This value is used when the
validation with embeddings is used. This is a hack.
extra_checkpoint: dict()
In case you want to use another model to initialize some variables.
This argument should be in the following format
extra_checkpoint = {"checkpoint_path": <YOUR_CHECKPOINT>,
"scopes": dict({"<SOURCE_SCOPE>/": "<TARGET_SCOPE>/"}),
"is_trainable": <IF_THOSE_LOADED_VARIABLES_ARE_TRAINABLE>
}
"""
def __init__(self,
architecture=None,
optimizer=None,
config=None,
n_classes=0,
loss_op=triplet_loss,
model_dir="",
validation_batch_size=None,
extra_checkpoint=None
):
self.architecture = architecture
self.optimizer=optimizer
self.n_classes=n_classes
self.loss_op=loss_op
self.loss = None
self.extra_checkpoint = extra_checkpoint
if self.architecture is None:
raise ValueError("Please specify a function to build the architecture !!")
......@@ -97,9 +104,6 @@ class Triplet(estimator.Estimator):
if self.loss_op is None:
raise ValueError("Please specify a function to build the loss !!")
if self.n_classes <= 0:
raise ValueError("Number of classes must be greated than 0")
def _model_fn(features, labels, mode, params, config):
if mode == tf.estimator.ModeKeys.TRAIN:
......@@ -110,11 +114,20 @@ class Triplet(estimator.Estimator):
'negative' in features.keys():
raise ValueError("The input function needs to contain a dictionary with the "
"keys `anchor`, `positive` and `negative` ")
if self.extra_checkpoint is None:
is_trainable = True
else:
is_trainable = is_trainable_checkpoint(self.extra_checkpoint)
# Building one graph
prelogits_anchor = self.architecture(features['anchor'])[0]
prelogits_positive = self.architecture(features['positive'], reuse=True)[0]
prelogits_negative = self.architecture(features['negative'], reuse=True)[0]
prelogits_anchor = self.architecture(features['anchor'], is_trainable=is_trainable)[0]
prelogits_positive = self.architecture(features['positive'], reuse=True, is_trainable=is_trainable)[0]
prelogits_negative = self.architecture(features['negative'], reuse=True, is_trainable=is_trainable)[0]
if self.extra_checkpoint is not None:
tf.contrib.framework.init_from_checkpoint(self.extra_checkpoint["checkpoint_path"],
self.extra_checkpoint["scopes"])
# Compute Loss (for both TRAIN and EVAL modes)
self.loss = self.loss_op(prelogits_anchor, prelogits_positive, prelogits_negative)
......
......@@ -5,14 +5,15 @@
import tensorflow as tf
from bob.learn.tensorflow.network import dummy
from bob.learn.tensorflow.estimators import Triplet
from bob.learn.tensorflow.estimators import Triplet, Logits
from bob.learn.tensorflow.dataset.triplet_image import shuffle_data_and_labels_image_augmentation as triplet_batch
from bob.learn.tensorflow.dataset.image import shuffle_data_and_labels_image_augmentation as single_batch
from bob.learn.tensorflow.loss import triplet_loss
from bob.learn.tensorflow.loss import triplet_loss, mean_cross_entropy_loss
from bob.learn.tensorflow.utils.hooks import LoggerHookEstimator
from bob.learn.tensorflow.utils import reproducible
import pkg_resources
from .test_estimator_transfer import dummy_adapted
import numpy
import shutil
......@@ -22,6 +23,7 @@ import os
tfrecord_train = "./train_mnist.tfrecord"
tfrecord_validation = "./validation_mnist.tfrecord"
model_dir = "./temp"
model_dir_adapted = "./temp2"
learning_rate = 0.001
data_shape = (250, 250, 3) # size of atnt images
......@@ -77,6 +79,45 @@ def test_triplet_estimator():
pass
def test_triplettrainer_transfer():
    """Check that a Triplet estimator can be initialized from a Logits checkpoint.

    A ``Logits`` estimator is trained first and writes its checkpoint to
    ``model_dir``.  A ``Triplet`` estimator is then built with
    ``extra_checkpoint`` pointing at that directory and is exercised through
    ``run_triplet_estimator``.  Both model directories are removed
    afterwards, whether or not the test succeeded.
    """

    def logits_input_fn():
        return single_batch(filenames, labels, data_shape, data_type, batch_size, epochs=epochs, output_shape=output_shape)

    # Description of the checkpoint used to initialize the Triplet model:
    # variables under the source scope "Dummy/" are mapped onto the target
    # scope "Dummy/", and the loaded variables are flagged as not trainable.
    extra_checkpoint = {"checkpoint_path": model_dir,
                        "scopes": {"Dummy/": "Dummy/"},
                        "is_trainable": False
                        }

    try:
        # Train the logits model first; its checkpoint is the source of the
        # transfer.
        logits_trainer = Logits(model_dir=model_dir,
                                architecture=dummy,
                                optimizer=tf.train.GradientDescentOptimizer(learning_rate),
                                n_classes=2,
                                loss_op=mean_cross_entropy_loss,
                                embedding_validation=False,
                                validation_batch_size=validation_batch_size)
        logits_trainer.train(logits_input_fn, steps=steps)

        # Now the triplet model, initialized from the logits checkpoint.
        trainer = Triplet(model_dir=model_dir_adapted,
                          architecture=dummy_adapted,
                          optimizer=tf.train.GradientDescentOptimizer(learning_rate),
                          loss_op=triplet_loss,
                          validation_batch_size=validation_batch_size,
                          extra_checkpoint=extra_checkpoint)
        run_triplet_estimator(trainer)
    finally:
        # ignore_errors=True already makes the cleanup best-effort, so no
        # extra try/except wrapper is needed around these calls.
        shutil.rmtree(model_dir, ignore_errors=True)
        shutil.rmtree(model_dir_adapted, ignore_errors=True)
def run_triplet_estimator(trainer):
# Cleaning up
......@@ -105,4 +146,3 @@ def run_triplet_estimator(trainer):
# Cleaning up
tf.reset_default_graph()
assert len(tf.global_variables()) == 0
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment