bob / bob.learn.tensorflow

Commit 2ab6b68d authored Jan 03, 2017 by Tiago de Freitas Pereira

Fixed issue #21

parent 5f82ac33
Pipeline #6176 failed with stages in 11 minutes and 17 seconds
Changes: 4    Pipelines: 1
bob/learn/tensorflow/datashuffler/Triplet.py
@@ -23,6 +23,12 @@ class Triplet(Base):
         self.data2_placeholder = None
         self.data3_placeholder = None
 
+    def set_placeholders(self, data, data2, data3):
+        self.data_placeholder = data
+        self.data2_placeholder = data2
+        self.data3_placeholder = data3
+
     def get_placeholders(self, name=""):
         """
         Returns a place holder with the size of your batch
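For context, the `set_placeholders` method added above lets a caller hand externally created placeholders to a triplet datashuffler instead of relying on the ones it builds itself. A minimal usage sketch, assuming the TensorFlow 1.x graph API and the `TripletMemory` shuffler used by the tests in this commit (the data, shapes and batch size below are illustrative only):

import numpy
import tensorflow as tf
from bob.learn.tensorflow.datashuffler import TripletMemory

# Illustrative data only; the tests in this commit use MNIST shaped (N, 28, 28, 1)
train_data = numpy.random.rand(100, 28, 28, 1).astype("float32")
train_labels = numpy.random.randint(0, 10, size=100)

shuffler = TripletMemory(train_data, train_labels,
                         input_shape=[28, 28, 1],
                         batch_size=16)

# One placeholder per triplet branch (anchor, positive, negative)
anchor = tf.placeholder(tf.float32, shape=(None, 28, 28, 1), name="anchor")
positive = tf.placeholder(tf.float32, shape=(None, 28, 28, 1), name="positive")
negative = tf.placeholder(tf.float32, shape=(None, 28, 28, 1), name="negative")

# New in this commit: override the shuffler's placeholders with the ones above
shuffler.set_placeholders(anchor, positive, negative)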
bob/learn/tensorflow/test/test_cnn_pretrained_model.py
@@ -6,12 +6,13 @@
 import numpy
 import bob.io.base
 import os
-from bob.learn.tensorflow.datashuffler import Memory, ImageAugmentation
-from bob.learn.tensorflow.loss import BaseLoss
-from bob.learn.tensorflow.trainers import Trainer, constant
+from bob.learn.tensorflow.datashuffler import Memory, ImageAugmentation, TripletMemory, SiameseMemory
+from bob.learn.tensorflow.loss import BaseLoss, TripletLoss, ContrastiveLoss
+from bob.learn.tensorflow.trainers import Trainer, constant, TripletTrainer, SiameseTrainer
 from bob.learn.tensorflow.utils import load_mnist
 from bob.learn.tensorflow.network import SequenceNetwork
 from bob.learn.tensorflow.layers import Conv2D, FullyConnected
 from test_cnn import dummy_experiment
 import tensorflow as tf
 import shutil
@@ -99,7 +100,7 @@ def test_cnn_pretrained():
     scratch = scratch_network()
     trainer = Trainer(architecture=scratch,
                       loss=loss,
-                      iterations=iterations + 200,
+                      iterations=iterations + 200,
                       analizer=None,
                       prefetch=False,
                       learning_rate=None,
@@ -118,3 +119,144 @@ def test_cnn_pretrained():
     del loss
     del trainer
 
+
+def test_triplet_cnn_pretrained():
+
+    train_data, train_labels, validation_data, validation_labels = load_mnist()
+    train_data = numpy.reshape(train_data, (train_data.shape[0], 28, 28, 1))
+
+    # Creating datashufflers
+    data_augmentation = ImageAugmentation()
+    train_data_shuffler = TripletMemory(train_data, train_labels,
+                                        input_shape=[28, 28, 1],
+                                        batch_size=batch_size,
+                                        data_augmentation=data_augmentation)
+    validation_data = numpy.reshape(validation_data, (validation_data.shape[0], 28, 28, 1))
+
+    validation_data_shuffler = TripletMemory(validation_data, validation_labels,
+                                             input_shape=[28, 28, 1],
+                                             batch_size=validation_batch_size)
+
+    directory = "./temp/cnn"
+    directory2 = "./temp/cnn2"
+
+    # Creating a random network
+    scratch = scratch_network()
+
+    # Loss for the softmax
+    loss = TripletLoss(margin=4.)
+
+    # One graph trainer
+    trainer = TripletTrainer(architecture=scratch,
+                             loss=loss,
+                             iterations=iterations,
+                             analizer=None,
+                             prefetch=False,
+                             learning_rate=constant(0.05, name="regular_lr"),
+                             optimizer=tf.train.AdamOptimizer(name="adam_pretrained_model"),
+                             temp_dir=directory)
+    trainer.train(train_data_shuffler)
+
+    # Testing
+    eer = dummy_experiment(validation_data_shuffler, scratch)
+
+    # The result is not so good
+    assert eer < 0.25
+
+    del scratch
+    del loss
+    del trainer
+
+    # Training the network using a pre trained model
+    loss = TripletLoss(margin=4.)
+    scratch = scratch_network()
+    trainer = TripletTrainer(architecture=scratch,
+                             loss=loss,
+                             iterations=iterations + 200,
+                             analizer=None,
+                             prefetch=False,
+                             learning_rate=None,
+                             temp_dir=directory2,
+                             model_from_file=os.path.join(directory, "model.ckp"))
+
+    trainer.train(train_data_shuffler)
+
+    eer = dummy_experiment(validation_data_shuffler, scratch)
+
+    # Now it is better
+    assert eer < 0.15
+
+    shutil.rmtree(directory)
+    shutil.rmtree(directory2)
+
+    del scratch
+    del loss
+    del trainer
+
+
+def test_siamese_cnn_pretrained():
+
+    train_data, train_labels, validation_data, validation_labels = load_mnist()
+    train_data = numpy.reshape(train_data, (train_data.shape[0], 28, 28, 1))
+
+    # Creating datashufflers
+    data_augmentation = ImageAugmentation()
+    train_data_shuffler = SiameseMemory(train_data, train_labels,
+                                        input_shape=[28, 28, 1],
+                                        batch_size=batch_size,
+                                        data_augmentation=data_augmentation)
+    validation_data = numpy.reshape(validation_data, (validation_data.shape[0], 28, 28, 1))
+
+    validation_data_shuffler = SiameseMemory(validation_data, validation_labels,
+                                             input_shape=[28, 28, 1],
+                                             batch_size=validation_batch_size)
+
+    directory = "./temp/cnn"
+    directory2 = "./temp/cnn2"
+
+    # Creating a random network
+    scratch = scratch_network()
+
+    # Loss for the softmax
+    loss = ContrastiveLoss(contrastive_margin=4.)
+
+    # One graph trainer
+    trainer = SiameseTrainer(architecture=scratch,
+                             loss=loss,
+                             iterations=iterations,
+                             analizer=None,
+                             prefetch=False,
+                             learning_rate=constant(0.05, name="regular_lr"),
+                             optimizer=tf.train.AdamOptimizer(name="adam_pretrained_model"),
+                             temp_dir=directory)
+    trainer.train(train_data_shuffler)
+
+    # Testing
+    eer = dummy_experiment(validation_data_shuffler, scratch)
+
+    # The result is not so good
+    assert eer < 0.28
+
+    del scratch
+    del loss
+    del trainer
+
+    # Training the network using a pre trained model
+    loss = ContrastiveLoss(contrastive_margin=4.)
+    scratch = scratch_network()
+    trainer = SiameseTrainer(architecture=scratch,
+                             loss=loss,
+                             iterations=iterations + 1000,
+                             analizer=None,
+                             prefetch=False,
+                             learning_rate=None,
+                             temp_dir=directory2,
+                             model_from_file=os.path.join(directory, "model.ckp"))
+
+    trainer.train(train_data_shuffler)
+
+    eer = dummy_experiment(validation_data_shuffler, scratch)
+
+    # Now it is better
+    assert eer < 0.25
+
+    shutil.rmtree(directory)
+    shutil.rmtree(directory2)
+
+    del scratch
+    del loss
+    del trainer
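Both new tests reuse the pattern of the existing `test_cnn_pretrained`: train from scratch, checkpoint into `./temp/cnn`, then build a second trainer that continues from that checkpoint through `model_from_file` and assert that the EER improves. The restore step presumably relies on TensorFlow's standard checkpoint machinery; a minimal, self-contained sketch of that save/restore round trip (TensorFlow 1.x graph API, with illustrative paths and variables) could look like this:

import os
import tensorflow as tf

directory = "./temp/cnn_sketch"               # illustrative path, not the tests' directory
checkpoint = os.path.join(directory, "model.ckp")
if not os.path.exists(directory):
    os.makedirs(directory)

# First graph/session: create a variable and write a checkpoint
with tf.Graph().as_default(), tf.Session() as sess:
    w = tf.Variable(tf.zeros([3, 3]), name="w")
    sess.run(tf.global_variables_initializer())
    tf.train.Saver().save(sess, checkpoint)

# Second graph/session: rebuild the same variables and restore their saved values,
# conceptually what `model_from_file` asks the new trainer to do before it resumes training
with tf.Graph().as_default(), tf.Session() as sess:
    w = tf.Variable(tf.zeros([3, 3]), name="w")
    tf.train.Saver().restore(sess, checkpoint)
    print(sess.run(w))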
bob/learn/tensorflow/trainers/SiameseTrainer.py
@@ -151,6 +151,46 @@ class SiameseTrainer(Trainer):
                                              tf.get_collection("validation_placeholder_data2")[0],
                                              tf.get_collection("validation_placeholder_label")[0])
 
+    def bootstrap_graphs(self, train_data_shuffler, validation_data_shuffler):
+        """
+        Create all the necessary graphs for training, validation and inference graphs
+        """
+
+        super(SiameseTrainer, self).bootstrap_graphs(train_data_shuffler, validation_data_shuffler)
+
+        # Triplet specific
+        tf.add_to_collection("between_class_graph_train", self.between_class_graph_train)
+        tf.add_to_collection("within_class_graph_train", self.within_class_graph_train)
+
+        # Creating validation graph
+        if validation_data_shuffler is not None:
+            tf.add_to_collection("between_class_graph_validation", self.between_class_graph_validation)
+            tf.add_to_collection("within_class_graph_validation", self.within_class_graph_validation)
+
+        self.bootstrap_placeholders(train_data_shuffler, validation_data_shuffler)
+
+    def bootstrap_graphs_fromfile(self, train_data_shuffler, validation_data_shuffler):
+        """
+        Bootstrap all the necessary data from file
+
+        ** Parameters **
+          session: Tensorflow session
+          train_data_shuffler: Data shuffler for training
+          validation_data_shuffler: Data shuffler for validation
+        """
+
+        saver = super(SiameseTrainer, self).bootstrap_graphs_fromfile(train_data_shuffler, validation_data_shuffler)
+
+        self.between_class_graph_train = tf.get_collection("between_class_graph_train")[0]
+        self.within_class_graph_train = tf.get_collection("within_class_graph_train")[0]
+
+        if validation_data_shuffler is not None:
+            self.between_class_graph_validation = tf.get_collection("between_class_graph_validation")[0]
+            self.within_class_graph_validation = tf.get_collection("within_class_graph_validation")[0]
+
+        self.bootstrap_placeholders_fromfile(train_data_shuffler, validation_data_shuffler)
+
+        return saver
+
     def compute_graph(self, data_shuffler, prefetch=False, name="", training=True):
         """
         Computes the graph for the trainer.
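The fix itself hinges on TensorFlow graph collections: `bootstrap_graphs` registers the between/within-class tensors under fixed collection keys when the graph is first built, and `bootstrap_graphs_fromfile` recovers them through the same keys after the graph has been re-imported from a saved meta file. A minimal sketch of that round trip, assuming the TensorFlow 1.x graph API (the tensors and file name below are illustrative stand-ins for the trainer's graphs):

import tensorflow as tf

meta_file = "/tmp/siamese_collections_sketch.meta"   # illustrative path

# Build a graph, tag a tensor with a collection key and export the meta graph
with tf.Graph().as_default():
    left = tf.placeholder(tf.float32, shape=(None, 10), name="left")
    right = tf.placeholder(tf.float32, shape=(None, 10), name="right")
    between_class = tf.reduce_mean(tf.square(left - right), name="between_class")
    tf.add_to_collection("between_class_graph_train", between_class)
    tf.train.export_meta_graph(meta_file)

# Later (as in bootstrap_graphs_fromfile): import the meta graph and look the
# tensor up again through the same collection key
with tf.Graph().as_default():
    tf.train.import_meta_graph(meta_file)
    restored = tf.get_collection("between_class_graph_train")[0]
    print(restored)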
bob/learn/tensorflow/trainers/TripletTrainer.py
@@ -118,6 +118,46 @@ class TripletTrainer(Trainer):
         self.between_class_graph_validation = None
         self.within_class_graph_validation = None
 
+    def bootstrap_graphs(self, train_data_shuffler, validation_data_shuffler):
+        """
+        Create all the necessary graphs for training, validation and inference graphs
+        """
+
+        super(TripletTrainer, self).bootstrap_graphs(train_data_shuffler, validation_data_shuffler)
+
+        # Triplet specific
+        tf.add_to_collection("between_class_graph_train", self.between_class_graph_train)
+        tf.add_to_collection("within_class_graph_train", self.within_class_graph_train)
+
+        # Creating validation graph
+        if validation_data_shuffler is not None:
+            tf.add_to_collection("between_class_graph_validation", self.between_class_graph_validation)
+            tf.add_to_collection("within_class_graph_validation", self.within_class_graph_validation)
+
+        self.bootstrap_placeholders(train_data_shuffler, validation_data_shuffler)
+
+    def bootstrap_graphs_fromfile(self, train_data_shuffler, validation_data_shuffler):
+        """
+        Bootstrap all the necessary data from file
+
+        ** Parameters **
+          session: Tensorflow session
+          train_data_shuffler: Data shuffler for training
+          validation_data_shuffler: Data shuffler for validation
+        """
+
+        saver = super(TripletTrainer, self).bootstrap_graphs_fromfile(train_data_shuffler, validation_data_shuffler)
+
+        self.between_class_graph_train = tf.get_collection("between_class_graph_train")[0]
+        self.within_class_graph_train = tf.get_collection("within_class_graph_train")[0]
+
+        if validation_data_shuffler is not None:
+            self.between_class_graph_validation = tf.get_collection("between_class_graph_validation")[0]
+            self.within_class_graph_validation = tf.get_collection("within_class_graph_validation")[0]
+
+        self.bootstrap_placeholders_fromfile(train_data_shuffler, validation_data_shuffler)
+
+        return saver
+
     def bootstrap_placeholders(self, train_data_shuffler, validation_data_shuffler):
         """
         Persist the placeholders
@@ -251,6 +291,7 @@ class TripletTrainer(Trainer):
                                                    self.within_class_graph_train,
                                                    self.learning_rate,
                                                    self.summaries_train],
                                                   feed_dict=feed_dict)
 
             logger.info("Loss training set step={0} = {1}".format(step, l))
             self.train_summary_writter.add_summary(summary, step)
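For reference, the training step shown above evaluates the optimizer together with both sub-losses, the learning rate and the summaries in a single `Session.run` call, so everything is computed against the same fed batch. A minimal sketch of that multi-fetch pattern (TensorFlow 1.x; the tensors below are illustrative stand-ins for the trainer's graphs):

import numpy
import tensorflow as tf

with tf.Graph().as_default(), tf.Session() as sess:
    data = tf.placeholder(tf.float32, shape=(None, 4), name="data")
    between_class = tf.reduce_mean(data)             # stand-in for between_class_graph_train
    within_class = tf.reduce_mean(tf.square(data))   # stand-in for within_class_graph_train
    loss = between_class + within_class

    batch = numpy.random.rand(8, 4).astype("float32")

    # One run call evaluates every fetch against the same feed_dict
    l, bt_class, wt_class = sess.run([loss, between_class, within_class],
                                     feed_dict={data: batch})
    print(l, bt_class, wt_class)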