bob / bob.learn.tensorflow / Commits

Commit 2d03b381
Authored Sep 23, 2016 by Tiago de Freitas Pereira
Fixed synchronization bug
Parent: 8bc92f21
Changes: 7 files
bob/learn/tensorflow/analyzers/__init__.py
@@ -2,7 +2,7 @@
 from pkgutil import extend_path
 __path__ = extend_path(__path__, __name__)

-from .Analizer import Analizer
+from .ExperimentAnalizer import ExperimentAnalizer

 # gets sphinx autodoc done right - don't remove it
 __all__ = [_ for _ in dir() if not _.startswith('_')]
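For context, the surrounding lines are standard Bob package boilerplate: pkgutil.extend_path keeps the analyzers subpackage importable across the split bob.* namespace packages, and the __all__ comprehension exports every non-underscore name so Sphinx autodoc picks up the public API, as the inline comment notes.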
bob/learn/tensorflow/data/BaseDataShuffler.py
@@ -66,10 +66,11 @@ class BaseDataShuffler(object):
         return data, labels

     def get_genuine_or_not(self, input_data, input_labels, genuine=True):
         if genuine:
             # Getting a client
             index = numpy.random.randint(len(self.possible_labels))
-            index = self.possible_labels[index]
+            index = int(self.possible_labels[index])

             # Getting the indexes of the data from a particular client
             indexes = numpy.where(input_labels == index)[0]
@@ -82,8 +83,8 @@ class BaseDataShuffler(object):
         else:
             # Picking a pair of labels from different clients
             index = numpy.random.choice(len(self.possible_labels), 2, replace=False)
-            index[0] = self.possible_labels[index[0]]
-            index[1] = self.possible_labels[index[1]]
+            index[0] = self.possible_labels[int(index[0])]
+            index[1] = self.possible_labels[int(index[1])]

             # Getting the indexes of the two clients
             indexes = numpy.where(input_labels == index[0])[0]
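Both hunks above make the same fix: the sampled label is normalized to a built-in int before it is used for matching and indexing, so the code behaves consistently whether possible_labels holds numpy scalars, floats, or plain ints. A minimal, self-contained sketch of the sampling pattern itself (the array contents are illustrative, not from the codebase):

    import numpy

    possible_labels = [1.0, 2.0, 3.0]            # client labels as stored, e.g. floats
    input_labels = numpy.array([1, 1, 2, 3, 3])  # one label per training sample

    # Genuine pair: draw one client, then gather every sample index of that client
    index = int(possible_labels[numpy.random.randint(len(possible_labels))])
    genuine_indexes = numpy.where(input_labels == index)[0]

    # Impostor pair: draw two distinct clients (replace=False forbids duplicates)
    pair = numpy.random.choice(len(possible_labels), 2, replace=False)
    client_a = possible_labels[int(pair[0])]
    client_b = possible_labels[int(pair[1])]
    impostor_indexes = numpy.where(input_labels == client_a)[0]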
bob/learn/tensorflow/script/train_mnist.py
@@ -89,8 +89,8 @@ def main():
     # Preparing the architecture
     cnn = True
     if cnn:
-        #architecture = Lenet(seed=SEED)
-        architecture = Dummy(seed=SEED)
+        architecture = Lenet(seed=SEED)
+        #architecture = Dummy(seed=SEED)
         loss = BaseLoss(tf.nn.sparse_softmax_cross_entropy_with_logits, tf.reduce_mean)
         trainer = Trainer(architecture=architecture, loss=loss, iterations=ITERATIONS)
         trainer.train(train_data_shuffler, validation_data_shuffler)
bob/learn/tensorflow/script/train_mnist_siamese.py
@@ -23,7 +23,7 @@ import tensorflow as tf
 from .. import util
 SEED = 10
 from bob.learn.tensorflow.data import MemoryDataShuffler, TextDataShuffler
-from bob.learn.tensorflow.network import Lenet, MLP, LenetDropout, VGG, Chopra
+from bob.learn.tensorflow.network import Lenet, MLP, LenetDropout, VGG, Chopra, Dummy
 from bob.learn.tensorflow.trainers import SiameseTrainer
 from bob.learn.tensorflow.loss import ContrastiveLoss
 import numpy
@@ -39,7 +39,7 @@ def main():
     perc_train = 0.9

     # Loading data
-    mnist = True
+    mnist = False
     if mnist:
         train_data, train_labels, validation_data, validation_labels = \
@@ -58,55 +58,70 @@ def main():
                                                   batch_size=VALIDATION_BATCH_SIZE)
     else:
-        import bob.db.atnt
-        db = bob.db.atnt.Database()
+        import bob.db.mobio
+        db_mobio = bob.db.mobio.Database()

-        #import bob.db.mobio
-        #db = bob.db.mobio.Database()
+        import bob.db.casia_webface
+        db_casia = bob.db.casia_webface.Database()

         # Preparing train set
         #train_objects = db.objects(protocol="male", groups="world")
-        train_objects = db.objects(groups="world")
-        train_labels = [o.client_id for o in train_objects]
-        #directory = "/idiap/user/tpereira/face/baselines/eigenface/preprocessed",
+        train_objects = db_casia.objects(groups="world")
+        #train_objects = db.objects(groups="world")
+        train_labels = [int(o.client_id) for o in train_objects]
+        directory = "/idiap/resource/database/CASIA-WebFace/CASIA-WebFace"
         train_file_names = [o.make_path(
-            directory="/idiap/group/biometric/databases/orl",
-            extension=".pgm")
+            directory=directory,
+            extension="")
             for o in train_objects]
+        #import ipdb;
+        #ipdb.set_trace();
+        #train_file_names = [o.make_path(
+        #    directory="/idiap/group/biometric/databases/orl",
+        #    extension=".pgm")
+        #    for o in train_objects]

-        #train_data_shuffler = TextDataShuffler(train_file_names, train_labels,
-        #                                       input_shape=[80, 64, 1],
-        #                                       batch_size=BATCH_SIZE)
         train_data_shuffler = TextDataShuffler(train_file_names, train_labels,
-                                               input_shape=[56, 46, 1],
+                                               input_shape=[250, 250, 3],
                                                batch_size=BATCH_SIZE)
+        #train_data_shuffler = TextDataShuffler(train_file_names, train_labels,
+        #                                       input_shape=[56, 46, 1],
+        #                                       batch_size=BATCH_SIZE)

         # Preparing train set
         #validation_objects = db.objects(protocol="male", groups="dev")
-        validation_objects = db.objects(groups="dev")
+        directory = "/idiap/temp/tpereira/DEEP_FACE/CASIA/preprocessed"
+        validation_objects = db_mobio.objects(protocol="male", groups="dev")
         validation_labels = [o.client_id for o in validation_objects]

+        #validation_file_names = [o.make_path(
+        #    directory="/idiap/group/biometric/databases/orl",
+        #    extension=".pgm")
+        #    for o in validation_objects]
         validation_file_names = [o.make_path(
-            directory="/idiap/group/biometric/databases/orl",
-            extension=".pgm")
+            directory=directory,
+            extension=".hdf5")
             for o in validation_objects]

-        #validation_data_shuffler = TextDataShuffler(validation_file_names, validation_labels,
-        #                                            input_shape=[80, 64, 1],
-        #                                            batch_size=VALIDATION_BATCH_SIZE)
         validation_data_shuffler = TextDataShuffler(validation_file_names, validation_labels,
-                                                    input_shape=[56, 46, 1],
+                                                    input_shape=[250, 250, 3],
                                                     batch_size=VALIDATION_BATCH_SIZE)
+        #validation_data_shuffler = TextDataShuffler(validation_file_names, validation_labels,
+        #                                            input_shape=[56, 46, 1],
+        #                                            batch_size=VALIDATION_BATCH_SIZE)

     # Preparing the architecture
-    n_classes = len(train_data_shuffler.possible_labels)
+    n_classes = 200
     cnn = True
     if cnn:
         # LENET PAPER CHOPRA
         #architecture = Chopra(default_feature_layer="fc7")
-        architecture = Lenet(default_feature_layer="fc2", n_classes=n_classes, conv1_output=4, conv2_output=8, use_gpu=USE_GPU)
+        architecture = Lenet(default_feature_layer="fc2", n_classes=n_classes, conv1_output=8, conv2_output=16, use_gpu=USE_GPU)
         #architecture = VGG(n_classes=n_classes, use_gpu=USE_GPU)
         #architecture = Dummy(seed=SEED)
         #architecture = LenetDropout(default_feature_layer="fc2", n_classes=n_classes, conv1_output=4, conv2_output=8, use_gpu=USE_GPU)
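The input_shape changes here follow directly from the dataset swap: the [56, 46, 1] gray-level crops used before are replaced for training by raw CASIA-WebFace images, which are 250x250 RGB (hence [250, 250, 3]), while validation now reads MOBIO objects from a preprocessed directory saved as .hdf5 files.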
@@ -115,7 +130,8 @@ def main():
         trainer = SiameseTrainer(architecture=architecture,
                                  loss=loss,
                                  iterations=ITERATIONS,
-                                 snapshot=VALIDATION_TEST)
+                                 snapshot=VALIDATION_TEST,
+                                 )
         trainer.train(train_data_shuffler, validation_data_shuffler)
     else:
         mlp = MLP(n_classes, hidden_layers=[15, 20])
bob/learn/tensorflow/trainers/SiameseTrainer.py
@@ -12,7 +12,7 @@ from ..network import SequenceNetwork
 import bob.io.base
 from .Trainer import Trainer
-import os
+import sys

 class SiameseTrainer(Trainer):
@@ -64,7 +64,8 @@ class SiameseTrainer(Trainer):
         """
         Injecting data in the place holder queue
         """
-        for i in range(self.iterations):
+        #for i in range(self.iterations+5):
+        while not thread_pool.should_stop():
             batch_left, batch_right, labels = train_data_shuffler.get_pair()

             feed_dict = {train_placeholder_left_data: batch_left,
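This hunk is the synchronization bug named in the commit message, and the same change is applied in Trainer.py further down: the feeder thread used to enqueue a fixed number of batches, so it could fall out of step with the training loop, either starving it or blocking after training finished. Looping on the thread pool's should_stop() keeps the producer alive exactly as long as the consumer needs it. A minimal sketch of the pattern with plain Python threading; the queue size and the make_batch stand-in are illustrative, not this module's API:

    import threading
    import queue
    import random

    stop_event = threading.Event()    # plays the role of thread_pool.should_stop()
    batches = queue.Queue(maxsize=32)

    def make_batch():
        # Stand-in for train_data_shuffler.get_pair()
        return [random.random() for _ in range(4)]

    def feeder():
        # Enqueue until the trainer signals shutdown, not for a fixed iteration count
        while not stop_event.is_set():
            try:
                batches.put(make_batch(), timeout=0.1)
            except queue.Full:
                pass

    def train(iterations):
        for step in range(iterations):
            batch = batches.get()     # consume one batch per optimization step
        stop_event.set()              # unblock and stop the feeder

    thread = threading.Thread(target=feeder)
    thread.start()
    train(100)
    thread.join()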
@@ -151,13 +152,13 @@ class SiameseTrainer(Trainer):
         self.architecture.generate_summaries()
         merged_validation = tf.merge_all_summaries()

         for step in range(self.iterations):
             _, l, lr, summary = session.run([optimizer, loss_train, learning_rate, merged])
             #_, l, lr= session.run([optimizer, loss_train, learning_rate])
             train_writer.add_summary(summary, step)
             print str(step)
+            sys.stdout.flush()

             if validation_data_shuffler is not None and step % self.snapshot == 0:
@@ -167,7 +168,9 @@ class SiameseTrainer(Trainer):
                 summary = analizer()
                 train_writer.add_summary(summary, step)
+                print str(step)
+                sys.stdout.flush()

         print("#######DONE##########")
         self.architecture.save(hdf5)
         del hdf5
         train_writer.close()
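The print str(step) / sys.stdout.flush() pairs added in these two hunks make progress visible as it happens: when stdout is redirected to a file, as on a grid job, Python block-buffers output, so without an explicit flush nothing shows up until the buffer fills or the process exits.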
bob/learn/tensorflow/trainers/Trainer.py
@@ -79,7 +79,8 @@ class Trainer(object):
         """
         #while not thread_pool.should_stop():
-        for i in range(self.iterations):
+        #for i in range(self.iterations):
+        while not thread_pool.should_stop():
             train_data, train_labels = train_data_shuffler.get_batch()

             feed_dict = {train_placeholder_data: train_data,
buildout.cfg
@@ -5,6 +5,7 @@
 [buildout]
 parts = scripts
 eggs = bob.learn.tensorflow
+       bob.db.casia_webface
        gridtk

 extensions = bob.buildout
@@ -12,6 +13,9 @@ extensions = bob.buildout
 auto-checkout = *
 develop = src/bob.db.mnist
           src/gridtk
+          src/bob.db.casia_webface
+          src/bob.db.mobio
+          src/bob.db.lfw
           .

 ; options for bob.buildout
@@ -21,7 +25,11 @@ newest = false
 [sources]
-bob.db.mnist = git git@github.com:tiagofrepereira2012/bob.db.mnist
+bob.db.mnist = git git@github.com:tiagofrepereira2012/bob.db.mnist.git
+bob.db.base = git git@gitlab.idiap.ch:bob/bob.db.base.git
+bob.db.mobio = git git@gitlab.idiap.ch:bob/bob.db.mobio.git
+bob.db.lfw = git git@gitlab.idiap.ch:bob/bob.db.lfw.git
+bob.db.casia_webface = git git@gitlab.idiap.ch:bob/bob.db.casia_webface.git
 gridtk = git git@github.com:bioidiap/gridtk
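Taken together, the buildout.cfg changes wire the new database packages in end to end: bob.db.casia_webface joins the eggs list, the three new src/ entries are installed in development mode, and the [sources] section tells bob.buildout's auto-checkout = * where to clone each repository, so a fresh ./bin/buildout run should fetch bob.db.casia_webface, bob.db.mobio, and bob.db.lfw alongside the existing checkouts.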