bob / bob.learn.tensorflow
Commit 57d0dd6b authored Sep 25, 2016 by Tiago de Freitas Pereira

Developing triplet loss

parent 2d03b381
Showing 8 changed files with 391 additions and 212 deletions.
bob/learn/tensorflow/data/BaseDataShuffler.py        +22    -0
bob/learn/tensorflow/data/MemoryDataShuffler.py       +8   -47
bob/learn/tensorflow/loss/TripletLoss.py             +45    -0
bob/learn/tensorflow/loss/__init__.py                 +1    -0
bob/learn/tensorflow/script/train_mnist_siamese.py    +1    -1
bob/learn/tensorflow/script/train_mnist_triplet.py  +115  -164
bob/learn/tensorflow/trainers/TripletTrainer.py     +198    -0
bob/learn/tensorflow/trainers/__init__.py             +1    -0
bob/learn/tensorflow/data/BaseDataShuffler.py

@@ -97,3 +97,25 @@ class BaseDataShuffler(object):
         data_p = input_data[indexes_p[0], ...]
         return data, data_p
 
+    def get_one_triplet(self, input_data, input_labels):
+
+        # Getting a pair of clients
+        index = numpy.random.choice(len(self.possible_labels), 2, replace=False)
+        label_positive = index[0]
+        label_negative = index[1]
+
+        # Getting the indexes of the data from a particular client
+        indexes = numpy.where(input_labels == index[0])[0]
+        numpy.random.shuffle(indexes)
+
+        # Picking a positive pair
+        data_anchor = input_data[indexes[0], ...]
+        data_positive = input_data[indexes[1], ...]
+
+        # Picking a negative sample
+        indexes = numpy.where(input_labels == index[1])[0]
+        numpy.random.shuffle(indexes)
+        data_negative = input_data[indexes[0], ...]
+
+        return data_anchor, data_positive, data_negative
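For reference only (not part of this commit), the sampling logic of get_one_triplet can be exercised on toy numpy arrays; the labels, shapes and values below are made up for illustration:

    import numpy

    labels = numpy.array([0, 0, 0, 1, 1, 1, 2, 2, 2])   # toy "client" labels
    data = numpy.arange(len(labels) * 4, dtype='float32').reshape(len(labels), 2, 2)

    # pick two distinct clients
    possible_labels = numpy.unique(labels)
    index = numpy.random.choice(len(possible_labels), 2, replace=False)

    # anchor and positive come from the first client
    indexes = numpy.where(labels == possible_labels[index[0]])[0]
    numpy.random.shuffle(indexes)
    data_anchor = data[indexes[0], ...]
    data_positive = data[indexes[1], ...]

    # negative comes from the second client
    indexes = numpy.where(labels == possible_labels[index[1]])[0]
    numpy.random.shuffle(indexes)
    data_negative = data[indexes[0], ...]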
bob/learn/tensorflow/data/MemoryDataShuffler.py

@@ -58,7 +58,7 @@ class MemoryDataShuffler(BaseDataShuffler):
 
         return selected_data.astype("float32"), selected_labels
 
-    def get_pair(self, train_dataset=True, zero_one_labels=True):
+    def get_pair(self, zero_one_labels=True):
         """
         Get a random pair of samples
 
@@ -82,10 +82,9 @@ class MemoryDataShuffler(BaseDataShuffler):
 
         return data, data_p, labels_siamese
 
-    def get_triplet(self, n_labels, n_triplets=1, is_target_set_train=True):
+    def get_random_triplet(self, n_triplets=1):
         """
-        Get a triplet
+        Get a random triplet
 
         **Parameters**
-        is_target_set_train: Defining the target set to get the batch
 
@@ -93,50 +92,12 @@ class MemoryDataShuffler(BaseDataShuffler):
 
         **Return**
         """
 
-        def get_one_triplet(input_data, input_labels):
-            # Getting a pair of clients
-            index = numpy.random.choice(n_labels, 2, replace=False)
-            label_positive = index[0]
-            label_negative = index[1]
-
-            # Getting the indexes of the data from a particular client
-            indexes = numpy.where(input_labels == index[0])[0]
-            numpy.random.shuffle(indexes)
-
-            # Picking a positive pair
-            data_anchor = input_data[indexes[0], :, :, :]
-            data_positive = input_data[indexes[1], :, :, :]
-
-            # Picking a negative sample
-            indexes = numpy.where(input_labels == index[1])[0]
-            numpy.random.shuffle(indexes)
-            data_negative = input_data[indexes[0], :, :, :]
-
-            return data_anchor, data_positive, data_negative, label_positive, label_positive, label_negative
-
-        if is_target_set_train:
-            target_data = self.train_data
-            target_labels = self.train_labels
-        else:
-            target_data = self.validation_data
-            target_labels = self.validation_labels
-
-        c = target_data.shape[3]
-        w = target_data.shape[1]
-        h = target_data.shape[2]
-
-        data_a = numpy.zeros(shape=(n_triplets, w, h, c), dtype='float32')
-        data_p = numpy.zeros(shape=(n_triplets, w, h, c), dtype='float32')
-        data_n = numpy.zeros(shape=(n_triplets, w, h, c), dtype='float32')
-        labels_a = numpy.zeros(shape=n_triplets, dtype='float32')
-        labels_p = numpy.zeros(shape=n_triplets, dtype='float32')
-        labels_n = numpy.zeros(shape=n_triplets, dtype='float32')
+        data_a = numpy.zeros(shape=self.shape, dtype='float32')
+        data_p = numpy.zeros(shape=self.shape, dtype='float32')
+        data_n = numpy.zeros(shape=self.shape, dtype='float32')
 
         for i in range(n_triplets):
-            data_a[i, :, :, :], data_p[i, :, :, :], data_n[i, :, :, :], \
-                labels_a[i], labels_p[i], labels_n[i] = \
-                get_one_triplet(target_data, target_labels)
+            data_a[i, ...], data_p[i, ...], data_n[i, ...] = self.get_one_triplet(self.data, self.labels)
 
-        return data_a, data_p, data_n, labels_a, labels_p, labels_n
+        return data_a, data_p, data_n
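Also for illustration only: the preallocate-and-fill pattern that get_random_triplet now uses, shown with toy shapes and a stand-in for get_one_triplet:

    import numpy

    n_triplets, w, h, c = 4, 28, 28, 1   # toy batch and image dimensions

    def one_triplet():
        # stand-in for self.get_one_triplet(self.data, self.labels)
        return (numpy.random.rand(w, h, c).astype('float32'),
                numpy.random.rand(w, h, c).astype('float32'),
                numpy.random.rand(w, h, c).astype('float32'))

    data_a = numpy.zeros(shape=(n_triplets, w, h, c), dtype='float32')
    data_p = numpy.zeros(shape=(n_triplets, w, h, c), dtype='float32')
    data_n = numpy.zeros(shape=(n_triplets, w, h, c), dtype='float32')

    for i in range(n_triplets):
        data_a[i, ...], data_p[i, ...], data_n[i, ...] = one_triplet()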
bob/learn/tensorflow/loss/TripletLoss.py (new file, 0 → 100644)

#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Wed 10 Aug 2016 16:38 CEST

import logging
logger = logging.getLogger("bob.learn.tensorflow")
import tensorflow as tf

from .BaseLoss import BaseLoss
from bob.learn.tensorflow.util import compute_euclidean_distance


class TripletLoss(BaseLoss):
    """
    Compute the triplet loss as in

    Schroff, Florian, Dmitry Kalenichenko, and James Philbin.
    "Facenet: A unified embedding for face recognition and clustering."
    Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2015.

    L = sum( |f_a - f_p|^2 - |f_a - f_n|^2 + \lambda)

    **Parameters**

    left_feature: First element of the pair
    right_feature: Second element of the pair
    label: Label of the pair (0 or 1)
    margin: Contrastive margin
    """

    def __init__(self, margin=2.0):
        self.margin = margin

    def __call__(self, anchor_feature, positive_feature, negative_feature):

        with tf.name_scope("triplet_loss"):
            d_positive = tf.square(compute_euclidean_distance(anchor_feature, positive_feature))
            d_negative = tf.square(compute_euclidean_distance(anchor_feature, negative_feature))

            loss = tf.maximum(0., d_positive - d_negative + self.margin)

            return tf.reduce_mean(loss), tf.reduce_mean(d_positive), tf.reduce_mean(d_negative)
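As a sanity check outside the commit, the same quantity can be reproduced with plain numpy, assuming compute_euclidean_distance (defined elsewhere in the package, not shown in this diff) returns the per-sample Euclidean distance, so that tf.square of it is the squared distance:

    import numpy

    margin = 2.0
    f_a = numpy.random.randn(8, 128).astype('float32')   # toy anchor embeddings
    f_p = numpy.random.randn(8, 128).astype('float32')   # toy positive embeddings
    f_n = numpy.random.randn(8, 128).astype('float32')   # toy negative embeddings

    d_positive = numpy.sum((f_a - f_p) ** 2, axis=1)      # squared anchor-positive distance
    d_negative = numpy.sum((f_a - f_n) ** 2, axis=1)      # squared anchor-negative distance
    loss = numpy.maximum(0., d_positive - d_negative + margin)

    print(loss.mean(), d_positive.mean(), d_negative.mean())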
bob/learn/tensorflow/loss/__init__.py

@@ -4,6 +4,7 @@ __path__ = extend_path(__path__, __name__)
 from .BaseLoss import BaseLoss
 from .ContrastiveLoss import ContrastiveLoss
+from .TripletLoss import TripletLoss
 
 # gets sphinx autodoc done right - don't remove it
 __all__ = [_ for _ in dir() if not _.startswith('_')]
bob/learn/tensorflow/script/train_mnist_siamese.py

@@ -113,7 +113,7 @@ def main():
 
     # Preparing the architecture
     n_classes = len(train_data_shuffler.possible_labels)
-    n_classes = 200
+    #n_classes = 200
     cnn = True
     if cnn:
bob/learn/tensorflow/script/train_mnist_triplet.py

(This diff is collapsed in the page view: +115 -164, not shown here.)
bob/learn/tensorflow/trainers/TripletTrainer.py (new file, 0 → 100644)

#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Tue 09 Aug 2016 15:25:22 CEST

import logging
logger = logging.getLogger("bob.learn.tensorflow")
import tensorflow as tf
import threading
from ..analyzers import ExperimentAnalizer
from ..network import SequenceNetwork
import bob.io.base
from .Trainer import Trainer
import os
import sys


class TripletTrainer(Trainer):

    def __init__(self,
                 architecture,
                 optimizer=tf.train.AdamOptimizer(),
                 use_gpu=False,
                 loss=None,
                 temp_dir="cnn",

                 # Learning rate
                 base_learning_rate=0.001,
                 weight_decay=0.9,

                 ###### training options ##########
                 convergence_threshold=0.01,
                 iterations=5000,
                 snapshot=100):

        super(TripletTrainer, self).__init__(
            architecture=architecture,
            optimizer=optimizer,
            use_gpu=use_gpu,
            loss=loss,
            temp_dir=temp_dir,
            base_learning_rate=base_learning_rate,
            weight_decay=weight_decay,
            convergence_threshold=convergence_threshold,
            iterations=iterations,
            snapshot=snapshot
        )

    def train(self, train_data_shuffler, validation_data_shuffler=None):
        """
        Do the loop forward --> backward --|
                       ^--------------------|
        """

        def start_thread():
            threads = []
            for n in range(1):
                t = threading.Thread(target=load_and_enqueue)
                t.daemon = True  # thread will close when parent quits
                t.start()
                threads.append(t)
            return threads

        def load_and_enqueue():
            """
            Injecting data in the place holder queue
            """
            #for i in range(self.iterations+5):
            while not thread_pool.should_stop():
                batch_anchor, batch_positive, batch_negative = train_data_shuffler.get_random_triplet()

                feed_dict = {train_placeholder_anchor_data: batch_anchor,
                             train_placeholder_positive_data: batch_positive,
                             train_placeholder_negative_data: batch_negative}

                session.run(enqueue_op, feed_dict=feed_dict)

        # TODO: find an elegant way to provide this as a parameter of the trainer
        learning_rate = tf.train.exponential_decay(
            self.base_learning_rate,  # Learning rate
            train_data_shuffler.batch_size,
            train_data_shuffler.n_samples,
            self.weight_decay  # Decay step
        )

        # Creating directory
        bob.io.base.create_directories_safe(self.temp_dir)

        # Creating two graphs
        #train_placeholder_anchor_data, _ = train_data_shuffler.get_placeholders_forprefetch(name="train_anchor")
        #train_placeholder_positive_data, _ = train_data_shuffler.get_placeholders_forprefetch(name="train_positive")
        #train_placeholder_negative_data, _ = train_data_shuffler.get_placeholders_forprefetch(name="train_negative")

        # Defining a placeholder queue for prefetching
        #queue = tf.FIFOQueue(capacity=100,
        #                     dtypes=[tf.float32, tf.float32, tf.float32],
        #                     shapes=[train_placeholder_anchor_data.get_shape().as_list()[1:],
        #                             train_placeholder_positive_data.get_shape().as_list()[1:],
        #                             train_placeholder_negative_data.get_shape().as_list()[1:]])

        # Fetching the place holders from the queue
        #enqueue_op = queue.enqueue_many([train_placeholder_anchor_data,
        #                                 train_placeholder_positive_data,
        #                                 train_placeholder_negative_data])
        #train_anchor_feature_batch, train_positive_label_batch, train_negative_label_batch = \
        #    queue.dequeue_many(train_data_shuffler.batch_size)

        # Creating the architecture for train and validation
        if not isinstance(self.architecture, SequenceNetwork):
            raise ValueError("The variable `architecture` must be an instance of "
                             "`bob.learn.tensorflow.network.SequenceNetwork`")

        #############
        train_anchor_feature_batch, _ = train_data_shuffler.get_placeholders(name="train_anchor")
        train_positive_feature_batch, _ = train_data_shuffler.get_placeholders(name="train_positive")
        train_negative_feature_batch, _ = train_data_shuffler.get_placeholders(name="train_negative")
        #############

        # Creating the siamese graph
        #import ipdb; ipdb.set_trace();
        train_anchor_graph = self.architecture.compute_graph(train_anchor_feature_batch)
        train_positive_graph = self.architecture.compute_graph(train_positive_feature_batch)
        train_negative_graph = self.architecture.compute_graph(train_negative_feature_batch)

        loss_train, within_class, between_class = self.loss(train_anchor_graph,
                                                            train_positive_graph,
                                                            train_negative_graph)

        # Preparing the optimizer
        step = tf.Variable(0)
        self.optimizer._learning_rate = learning_rate
        optimizer = self.optimizer.minimize(loss_train, global_step=step)
        #optimizer = tf.train.MomentumOptimizer(learning_rate, momentum=0.99, use_locking=False,
        #                                       name='Momentum').minimize(loss_train, global_step=step)

        print("Initializing !!")

        # Training
        hdf5 = bob.io.base.HDF5File(os.path.join(self.temp_dir, 'model.hdf5'), 'w')

        with tf.Session() as session:
            if validation_data_shuffler is not None:
                analizer = ExperimentAnalizer(validation_data_shuffler, self.architecture, session)

            tf.initialize_all_variables().run()

            # Start a thread to enqueue data asynchronously, and hide I/O latency.
            #thread_pool = tf.train.Coordinator()
            #tf.train.start_queue_runners(coord=thread_pool)
            #threads = start_thread()

            # TENSOR BOARD SUMMARY
            train_writer = tf.train.SummaryWriter(os.path.join(self.temp_dir, 'LOGS'), session.graph)

            # Siamese specific summary
            tf.scalar_summary('loss', loss_train)
            tf.scalar_summary('between_class', between_class)
            tf.scalar_summary('within_class', within_class)
            tf.scalar_summary('lr', learning_rate)
            merged = tf.merge_all_summaries()

            # Architecture summary
            self.architecture.generate_summaries()
            merged_validation = tf.merge_all_summaries()

            for step in range(self.iterations):

                batch_anchor, batch_positive, batch_negative = train_data_shuffler.get_random_triplet()

                feed_dict = {train_anchor_feature_batch: batch_anchor,
                             train_positive_feature_batch: batch_positive,
                             train_negative_feature_batch: batch_negative}

                _, l, lr, summary = session.run([optimizer, loss_train, learning_rate, merged],
                                                feed_dict=feed_dict)
                #_, l, lr= session.run([optimizer, loss_train, learning_rate])
                train_writer.add_summary(summary, step)
                print str(step) + " -- loss: " + str(l)
                sys.stdout.flush()

                if validation_data_shuffler is not None and step % self.snapshot == 0:
                    #summary = session.run(merged_validation)
                    #train_writer.add_summary(summary, step)

                    summary = analizer()
                    train_writer.add_summary(summary, step)
                    print str(step)
                    sys.stdout.flush()

            print("#######DONE##########")
            self.architecture.save(hdf5)
            del hdf5
            train_writer.close()

            #thread_pool.request_stop()
            #thread_pool.join(threads)
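A side note on the learning-rate schedule above: tf.train.exponential_decay(learning_rate, global_step, decay_steps, decay_rate) computes learning_rate * decay_rate ** (global_step / decay_steps), so with the arguments passed here weight_decay plays the role of the decay rate. A small sketch with illustrative numbers (not taken from the commit):

    base_learning_rate = 0.001
    weight_decay = 0.9        # used as the decay rate
    decay_steps = 60000       # e.g. train_data_shuffler.n_samples for MNIST (assumed)

    for global_step in (0, 15000, 30000, 60000):
        lr = base_learning_rate * weight_decay ** (global_step / float(decay_steps))
        print(global_step, lr)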
bob/learn/tensorflow/trainers/__init__.py

@@ -4,6 +4,7 @@ __path__ = extend_path(__path__, __name__)
 from .Trainer import Trainer
 from .SiameseTrainer import SiameseTrainer
+from .TripletTrainer import TripletTrainer
 import numpy