Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
bob
bob.learn.tensorflow
Commits
db9747e1
Commit
db9747e1
authored
Nov 10, 2016
by
Tiago de Freitas Pereira
Browse files
New session management
parent
0fe1b037
Changes
15
Hide whitespace changes
Inline
Side-by-side
bob/learn/__init__.py
View file @
db9747e1
...
...
@@ -9,4 +9,5 @@ from bob.learn.tensorflow import layers
from
bob.learn.tensorflow
import
loss
from
bob.learn.tensorflow
import
network
from
bob.learn.tensorflow
import
trainers
from
bob.learn.tensorflow
import
utils
bob/learn/tensorflow/__init__.py
View file @
db9747e1
...
...
@@ -2,7 +2,6 @@
from
pkgutil
import
extend_path
__path__
=
extend_path
(
__path__
,
__name__
)
from
.util
import
*
# gets sphinx autodoc done right - don't remove it
__all__
=
[
_
for
_
in
dir
()
if
not
_
.
startswith
(
'_'
)]
bob/learn/tensorflow/layers/AveragePooling.py
View file @
db9747e1
...
...
@@ -4,7 +4,6 @@
# @date: Wed 11 May 2016 17:38 CEST
import
tensorflow
as
tf
from
bob.learn.tensorflow.util
import
*
from
.MaxPooling
import
MaxPooling
...
...
bob/learn/tensorflow/layers/InputLayer.py
View file @
db9747e1
...
...
@@ -4,7 +4,6 @@
# @date: Wed 11 May 2016 17:38 CEST
import
tensorflow
as
tf
from
bob.learn.tensorflow.util
import
*
from
.Layer
import
Layer
...
...
bob/learn/tensorflow/layers/MaxPooling.py
View file @
db9747e1
...
...
@@ -4,7 +4,6 @@
# @date: Wed 11 May 2016 17:38 CEST
import
tensorflow
as
tf
from
bob.learn.tensorflow.util
import
*
from
.Layer
import
Layer
...
...
bob/learn/tensorflow/loss/ContrastiveLoss.py
View file @
db9747e1
...
...
@@ -8,7 +8,7 @@ logger = logging.getLogger("bob.learn.tensorflow")
import
tensorflow
as
tf
from
.BaseLoss
import
BaseLoss
from
bob.learn.tensorflow.util
import
compute_euclidean_distance
from
bob.learn.tensorflow.util
s
import
compute_euclidean_distance
class
ContrastiveLoss
(
BaseLoss
):
...
...
bob/learn/tensorflow/loss/TripletLoss.py
View file @
db9747e1
...
...
@@ -8,7 +8,7 @@ logger = logging.getLogger("bob.learn.tensorflow")
import
tensorflow
as
tf
from
.BaseLoss
import
BaseLoss
from
bob.learn.tensorflow.util
import
compute_euclidean_distance
from
bob.learn.tensorflow.util
s
import
compute_euclidean_distance
class
TripletLoss
(
BaseLoss
):
...
...
bob/learn/tensorflow/network/SequenceNetwork.py
View file @
db9747e1
...
...
@@ -12,6 +12,7 @@ import pickle
from
collections
import
OrderedDict
from
bob.learn.tensorflow.layers
import
Layer
,
MaxPooling
,
Dropout
,
Conv2D
,
FullyConnected
from
bob.learn.tensorflow.utils.session
import
Session
class
SequenceNetwork
(
six
.
with_metaclass
(
abc
.
ABCMeta
,
object
)):
...
...
@@ -102,7 +103,7 @@ class SequenceNetwork(six.with_metaclass(abc.ABCMeta, object)):
def
compute_inference_placeholder
(
self
,
data_shape
):
self
.
inference_placeholder
=
tf
.
placeholder
(
tf
.
float32
,
shape
=
data_shape
,
name
=
"feature"
)
def
__call__
(
self
,
data
,
session
=
None
,
feature_layer
=
None
):
def
__call__
(
self
,
data
,
feature_layer
=
None
):
"""Run a graph and compute the embeddings
**Parameters**
...
...
@@ -115,8 +116,7 @@ class SequenceNetwork(six.with_metaclass(abc.ABCMeta, object)):
If `None` will run the graph until the end.
"""
if
session
is
None
:
session
=
tf
.
Session
()
session
=
Session
.
instance
().
session
# Feeding the placeholder
if
self
.
inference_placeholder
is
None
:
...
...
@@ -130,8 +130,8 @@ class SequenceNetwork(six.with_metaclass(abc.ABCMeta, object)):
return
embedding
def
predict
(
self
,
data
,
session
):
return
numpy
.
argmax
(
self
(
data
,
session
=
session
),
1
)
def
predict
(
self
,
data
):
return
numpy
.
argmax
(
self
(
data
),
1
)
def
dump_variables
(
self
):
"""
...
...
@@ -252,10 +252,13 @@ class SequenceNetwork(six.with_metaclass(abc.ABCMeta, object)):
self
.
sequence_net
[
k
].
weights_initialization
.
use_gpu
=
state
self
.
sequence_net
[
k
].
bias_initialization
.
use_gpu
=
state
def
load_variables_only
(
self
,
hdf5
,
session
):
def
load_variables_only
(
self
,
hdf5
):
"""
Load the variables of the model
"""
session
=
Session
.
instance
().
session
hdf5
.
cd
(
'/tensor_flow'
)
for
k
in
self
.
sequence_net
:
# TODO: IT IS NOT SMART TESTING ALONG THIS PAGE
...
...
@@ -271,7 +274,7 @@ class SequenceNetwork(six.with_metaclass(abc.ABCMeta, object)):
hdf5
.
cd
(
".."
)
def
load_hdf5
(
self
,
hdf5
,
shape
=
None
,
session
=
None
,
batch
=
1
,
use_gpu
=
False
):
def
load_hdf5
(
self
,
hdf5
,
shape
=
None
,
batch
=
1
,
use_gpu
=
False
):
"""
Load the network from scratch.
This will build the graphs
...
...
@@ -285,8 +288,7 @@ class SequenceNetwork(six.with_metaclass(abc.ABCMeta, object)):
use_gpu: Load all the variables in the GPU?
"""
if
session
is
None
:
session
=
tf
.
Session
()
session
=
Session
.
instance
().
session
# Loading the normalization parameters
self
.
input_divide
=
hdf5
.
read
(
'input_divide'
)
...
...
@@ -308,11 +310,17 @@ class SequenceNetwork(six.with_metaclass(abc.ABCMeta, object)):
tf
.
initialize_all_variables
().
run
(
session
=
session
)
self
.
load_variables_only
(
hdf5
,
session
)
def
save
(
self
,
session
,
saver
,
path
):
def
save
(
self
,
saver
,
path
):
session
=
Session
.
instance
().
session
open
(
path
+
"_sequence_net.pickle"
,
'w'
).
write
(
self
.
pickle_architecture
)
return
saver
.
save
(
session
,
path
)
def
load
(
self
,
session
,
path
,
clear_devices
=
False
):
def
load
(
self
,
path
,
clear_devices
=
False
):
session
=
Session
.
instance
().
session
self
.
sequence_net
=
pickle
.
loads
(
open
(
path
+
"_sequence_net.pickle"
).
read
())
#saver = tf.train.import_meta_graph(path + ".meta", clear_devices=clear_devices)
saver
=
tf
.
train
.
import_meta_graph
(
path
+
".meta"
)
...
...
bob/learn/tensorflow/test/test_cnn_scratch.py
View file @
db9747e1
...
...
@@ -11,7 +11,7 @@ from bob.learn.tensorflow.initialization import Xavier, Constant
from
bob.learn.tensorflow.network
import
SequenceNetwork
from
bob.learn.tensorflow.loss
import
BaseLoss
from
bob.learn.tensorflow.trainers
import
Trainer
from
bob.learn.tensorflow.util
import
load_mnist
from
bob.learn.tensorflow.util
s
import
load_mnist
from
bob.learn.tensorflow.layers
import
Conv2D
,
FullyConnected
,
MaxPooling
import
tensorflow
as
tf
import
shutil
...
...
@@ -46,22 +46,21 @@ def scratch_network():
return
scratch
def
validate_network
(
validation_data
,
validation_labels
,
direct
or
y
):
def
validate_network
(
validation_data
,
validation_labels
,
netw
or
k
):
# Testing
validation_data_shuffler
=
Memory
(
validation_data
,
validation_labels
,
input_shape
=
[
28
,
28
,
1
],
batch_size
=
validation_batch_size
)
with
tf
.
Session
()
as
session
:
scratch
=
SequenceNetwork
()
scratch
.
load
(
session
,
os
.
path
.
join
(
directory
,
"model.ckp"
))
[
data
,
labels
]
=
validation_data_shuffler
.
get_batch
()
predictions
=
scratch
(
data
,
session
=
session
)
accuracy
=
100.
*
numpy
.
sum
(
numpy
.
argmax
(
predictions
,
1
)
==
labels
)
/
predictions
.
shape
[
0
]
[
data
,
labels
]
=
validation_data_shuffler
.
get_batch
()
predictions
=
network
.
predict
(
data
)
accuracy
=
100.
*
numpy
.
sum
(
predictions
==
labels
)
/
predictions
.
shape
[
0
]
return
accuracy
def
test_cnn_trainer_scratch
():
train_data
,
train_labels
,
validation_data
,
validation_labels
=
load_mnist
()
train_data
=
numpy
.
reshape
(
train_data
,
(
train_data
.
shape
[
0
],
28
,
28
,
1
))
...
...
@@ -86,11 +85,10 @@ def test_cnn_trainer_scratch():
analizer
=
None
,
prefetch
=
False
,
temp_dir
=
directory
)
trainer
.
train
(
train_data_shuffler
)
del
trainer
# JUst to clean the tf.variables
trainer
.
train
(
train_data_shuffler
)
accuracy
=
validate_network
(
validation_data
,
validation_labels
,
directory
)
accuracy
=
validate_network
(
validation_data
,
validation_labels
,
scratch
)
assert
accuracy
>
80
shutil
.
rmtree
(
directory
)
del
trainer
bob/learn/tensorflow/trainers/Trainer.py
View file @
db9747e1
...
...
@@ -13,6 +13,7 @@ from ..analyzers import SoftmaxAnalizer
from
tensorflow.core.framework
import
summary_pb2
import
time
from
bob.learn.tensorflow.datashuffler.OnlineSampling
import
OnLineSampling
from
bob.learn.tensorflow.utils.session
import
Session
from
.learning_rate
import
constant
logger
=
bob
.
core
.
log
.
setup
(
"bob.learn.tensorflow"
)
...
...
@@ -103,6 +104,7 @@ class Trainer(object):
self
.
global_step
=
None
self
.
model_from_file
=
model_from_file
self
.
session
=
None
bob
.
core
.
log
.
set_verbosity_level
(
logger
,
verbosity_level
)
...
...
@@ -162,7 +164,7 @@ class Trainer(object):
label_placeholder
:
labels
}
return
feed_dict
def
fit
(
self
,
session
,
step
):
def
fit
(
self
,
step
):
"""
Run one iteration (`forward` and `backward`)
...
...
@@ -173,17 +175,17 @@ class Trainer(object):
"""
if
self
.
prefetch
:
_
,
l
,
lr
,
summary
=
session
.
run
([
self
.
optimizer
,
self
.
training_graph
,
_
,
l
,
lr
,
summary
=
self
.
session
.
run
([
self
.
optimizer
,
self
.
training_graph
,
self
.
learning_rate
,
self
.
summaries_train
])
else
:
feed_dict
=
self
.
get_feed_dict
(
self
.
train_data_shuffler
)
_
,
l
,
lr
,
summary
=
session
.
run
([
self
.
optimizer
,
self
.
training_graph
,
_
,
l
,
lr
,
summary
=
self
.
session
.
run
([
self
.
optimizer
,
self
.
training_graph
,
self
.
learning_rate
,
self
.
summaries_train
],
feed_dict
=
feed_dict
)
logger
.
info
(
"Loss training set step={0} = {1}"
.
format
(
step
,
l
))
self
.
train_summary_writter
.
add_summary
(
summary
,
step
)
def
compute_validation
(
self
,
session
,
data_shuffler
,
step
):
def
compute_validation
(
self
,
data_shuffler
,
step
):
"""
Computes the loss in the validation set
...
...
@@ -195,10 +197,10 @@ class Trainer(object):
"""
# Opening a new session for validation
feed_dict
=
self
.
get_feed_dict
(
data_shuffler
)
l
=
session
.
run
(
self
.
validation_graph
,
feed_dict
=
feed_dict
)
l
=
self
.
session
.
run
(
self
.
validation_graph
,
feed_dict
=
feed_dict
)
if
self
.
validation_summary_writter
is
None
:
self
.
validation_summary_writter
=
tf
.
train
.
SummaryWriter
(
os
.
path
.
join
(
self
.
temp_dir
,
'validation'
),
session
.
graph
)
self
.
validation_summary_writter
=
tf
.
train
.
SummaryWriter
(
os
.
path
.
join
(
self
.
temp_dir
,
'validation'
),
self
.
session
.
graph
)
summaries
=
[
summary_pb2
.
Summary
.
Value
(
tag
=
"loss"
,
simple_value
=
float
(
l
))]
self
.
validation_summary_writter
.
add_summary
(
summary_pb2
.
Summary
(
value
=
summaries
),
step
)
...
...
@@ -213,7 +215,7 @@ class Trainer(object):
tf
.
scalar_summary
(
'lr'
,
self
.
learning_rate
,
name
=
"train"
)
return
tf
.
merge_all_summaries
()
def
start_thread
(
self
,
session
):
def
start_thread
(
self
):
"""
Start pool of threads for pre-fetching
...
...
@@ -223,13 +225,13 @@ class Trainer(object):
threads
=
[]
for
n
in
range
(
3
):
t
=
threading
.
Thread
(
target
=
self
.
load_and_enqueue
,
args
=
(
session
,
))
t
=
threading
.
Thread
(
target
=
self
.
load_and_enqueue
,
args
=
())
t
.
daemon
=
True
# thread will close when parent quits
t
.
start
()
threads
.
append
(
t
)
return
threads
def
load_and_enqueue
(
self
,
session
):
def
load_and_enqueue
(
self
):
"""
Injecting data in the place holder queue
...
...
@@ -244,7 +246,7 @@ class Trainer(object):
feed_dict
=
{
train_placeholder_data
:
train_data
,
train_placeholder_labels
:
train_labels
}
session
.
run
(
self
.
enqueue_op
,
feed_dict
=
feed_dict
)
self
.
session
.
run
(
self
.
enqueue_op
,
feed_dict
=
feed_dict
)
def
bootstrap_graphs
(
self
,
train_data_shuffler
,
validation_data_shuffler
):
"""
...
...
@@ -293,7 +295,7 @@ class Trainer(object):
tf
.
add_to_collection
(
"validation_placeholder_data"
,
batch
)
tf
.
add_to_collection
(
"validation_placeholder_label"
,
label
)
def
bootstrap_graphs_fromfile
(
self
,
session
,
train_data_shuffler
,
validation_data_shuffler
):
def
bootstrap_graphs_fromfile
(
self
,
train_data_shuffler
,
validation_data_shuffler
):
"""
Bootstrap all the necessary data from file
...
...
@@ -304,7 +306,7 @@ class Trainer(object):
"""
saver
=
self
.
architecture
.
load
(
session
,
self
.
model_from_file
)
saver
=
self
.
architecture
.
load
(
self
.
session
,
self
.
model_from_file
)
# Loading training graph
self
.
training_graph
=
tf
.
get_collection
(
"training_graph"
)[
0
]
...
...
@@ -362,78 +364,80 @@ class Trainer(object):
# Pickle the architecture to save
self
.
architecture
.
pickle_net
(
train_data_shuffler
.
deployment_shape
)
with
tf
.
Session
(
config
=
config
)
as
session
:
# Loading a pretrained model
if
self
.
model_from_file
!=
""
:
logger
.
info
(
"Loading pretrained model from {0}"
.
format
(
self
.
model_from_file
))
saver
=
self
.
bootstrap_graphs_fromfile
(
session
,
train_data_shuffler
,
validation_data_shuffler
)
else
:
# Bootstraping all the graphs
self
.
bootstrap_graphs
(
train_data_shuffler
,
validation_data_shuffler
)
# TODO: find an elegant way to provide this as a parameter of the trainer
self
.
global_step
=
tf
.
Variable
(
0
,
trainable
=
False
)
# Preparing the optimizer
self
.
optimizer_class
.
_learning_rate
=
self
.
learning_rate
self
.
optimizer
=
self
.
optimizer_class
.
minimize
(
self
.
training_graph
,
global_step
=
self
.
global_step
)
tf
.
add_to_collection
(
"optimizer"
,
self
.
optimizer
)
tf
.
add_to_collection
(
"learning_rate"
,
self
.
learning_rate
)
# Train summary
self
.
summaries_train
=
self
.
create_general_summary
()
tf
.
add_to_collection
(
"summaries_train"
,
self
.
summaries_train
)
tf
.
initialize_all_variables
().
run
()
# Original tensorflow saver object
saver
=
tf
.
train
.
Saver
(
var_list
=
tf
.
all_variables
())
if
isinstance
(
train_data_shuffler
,
OnLineSampling
):
train_data_shuffler
.
set_feature_extractor
(
self
.
architecture
,
session
=
session
)
# Start a thread to enqueue data asynchronously, and hide I/O latency.
if
self
.
prefetch
:
self
.
thread_pool
=
tf
.
train
.
Coordinator
()
tf
.
train
.
start_queue_runners
(
coord
=
self
.
thread_pool
)
threads
=
self
.
start_thread
(
session
)
# TENSOR BOARD SUMMARY
self
.
train_summary_writter
=
tf
.
train
.
SummaryWriter
(
os
.
path
.
join
(
self
.
temp_dir
,
'train'
),
session
.
graph
)
for
step
in
range
(
self
.
iterations
):
start
=
time
.
time
()
self
.
fit
(
session
,
step
)
end
=
time
.
time
()
summary
=
summary_pb2
.
Summary
.
Value
(
tag
=
"elapsed_time"
,
simple_value
=
float
(
end
-
start
))
self
.
train_summary_writter
.
add_summary
(
summary_pb2
.
Summary
(
value
=
[
summary
]),
step
)
# Running validation
if
validation_data_shuffler
is
not
None
and
step
%
self
.
validation_snapshot
==
0
:
self
.
compute_validation
(
session
,
validation_data_shuffler
,
step
)
if
self
.
analizer
is
not
None
:
self
.
validation_summary_writter
.
add_summary
(
self
.
analizer
(
validation_data_shuffler
,
self
.
architecture
,
session
),
step
)
# Taking snapshot
if
step
%
self
.
snapshot
==
0
:
logger
.
info
(
"Taking snapshot"
)
path
=
os
.
path
.
join
(
self
.
temp_dir
,
'model_snapshot{0}.ckp'
.
format
(
step
))
self
.
architecture
.
save
(
session
,
saver
,
path
)
logger
.
info
(
"Training finally finished"
)
self
.
train_summary_writter
.
close
()
if
validation_data_shuffler
is
not
None
:
self
.
validation_summary_writter
.
close
()
# Saving the final network
path
=
os
.
path
.
join
(
self
.
temp_dir
,
'model.ckp'
)
self
.
architecture
.
save
(
session
,
saver
,
path
)
if
self
.
prefetch
:
# now they should definetely stop
self
.
thread_pool
.
request_stop
()
self
.
thread_pool
.
join
(
threads
)
#with tf.Session(config=config) as session:
self
.
session
=
Session
.
instance
().
session
# Loading a pretrained model
if
self
.
model_from_file
!=
""
:
logger
.
info
(
"Loading pretrained model from {0}"
.
format
(
self
.
model_from_file
))
saver
=
self
.
bootstrap_graphs_fromfile
(
self
.
session
,
train_data_shuffler
,
validation_data_shuffler
)
else
:
# Bootstraping all the graphs
self
.
bootstrap_graphs
(
train_data_shuffler
,
validation_data_shuffler
)
# TODO: find an elegant way to provide this as a parameter of the trainer
self
.
global_step
=
tf
.
Variable
(
0
,
trainable
=
False
)
# Preparing the optimizer
self
.
optimizer_class
.
_learning_rate
=
self
.
learning_rate
self
.
optimizer
=
self
.
optimizer_class
.
minimize
(
self
.
training_graph
,
global_step
=
self
.
global_step
)
tf
.
add_to_collection
(
"optimizer"
,
self
.
optimizer
)
tf
.
add_to_collection
(
"learning_rate"
,
self
.
learning_rate
)
# Train summary
self
.
summaries_train
=
self
.
create_general_summary
()
tf
.
add_to_collection
(
"summaries_train"
,
self
.
summaries_train
)
tf
.
initialize_all_variables
().
run
(
session
=
self
.
session
)
# Original tensorflow saver object
saver
=
tf
.
train
.
Saver
(
var_list
=
tf
.
all_variables
())
if
isinstance
(
train_data_shuffler
,
OnLineSampling
):
train_data_shuffler
.
set_feature_extractor
(
self
.
architecture
,
session
=
self
.
session
)
# Start a thread to enqueue data asynchronously, and hide I/O latency.
if
self
.
prefetch
:
self
.
thread_pool
=
tf
.
train
.
Coordinator
()
tf
.
train
.
start_queue_runners
(
coord
=
self
.
thread_pool
)
threads
=
self
.
start_thread
(
self
.
session
)
# TENSOR BOARD SUMMARY
self
.
train_summary_writter
=
tf
.
train
.
SummaryWriter
(
os
.
path
.
join
(
self
.
temp_dir
,
'train'
),
self
.
session
.
graph
)
for
step
in
range
(
self
.
iterations
):
start
=
time
.
time
()
self
.
fit
(
self
.
session
,
step
)
end
=
time
.
time
()
summary
=
summary_pb2
.
Summary
.
Value
(
tag
=
"elapsed_time"
,
simple_value
=
float
(
end
-
start
))
self
.
train_summary_writter
.
add_summary
(
summary_pb2
.
Summary
(
value
=
[
summary
]),
step
)
# Running validation
if
validation_data_shuffler
is
not
None
and
step
%
self
.
validation_snapshot
==
0
:
self
.
compute_validation
(
self
.
session
,
validation_data_shuffler
,
step
)
if
self
.
analizer
is
not
None
:
self
.
validation_summary_writter
.
add_summary
(
self
.
analizer
(
validation_data_shuffler
,
self
.
architecture
,
self
.
session
),
step
)
# Taking snapshot
if
step
%
self
.
snapshot
==
0
:
logger
.
info
(
"Taking snapshot"
)
path
=
os
.
path
.
join
(
self
.
temp_dir
,
'model_snapshot{0}.ckp'
.
format
(
step
))
self
.
architecture
.
save
(
saver
,
path
)
logger
.
info
(
"Training finally finished"
)
self
.
train_summary_writter
.
close
()
if
validation_data_shuffler
is
not
None
:
self
.
validation_summary_writter
.
close
()
# Saving the final network
path
=
os
.
path
.
join
(
self
.
temp_dir
,
'model.ckp'
)
self
.
architecture
.
save
(
saver
,
path
)
if
self
.
prefetch
:
# now they should definetely stop
self
.
thread_pool
.
request_stop
()
self
.
thread_pool
.
join
(
threads
)
bob/learn/tensorflow/utils/__init__.py
0 → 100644
View file @
db9747e1
from
.util
import
*
from
.singleton
import
Singleton
from
.session
import
Session
\ No newline at end of file
bob/learn/tensorflow/utils/session.py
0 → 100644
View file @
db9747e1
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Wed 11 May 2016 09:39:36 CEST
import
tensorflow
as
tf
from
.singleton
import
Singleton
@
Singleton
class
Session
(
object
):
def
__init__
(
self
):
config
=
tf
.
ConfigProto
(
log_device_placement
=
True
,
gpu_options
=
tf
.
GPUOptions
(
per_process_gpu_memory_fraction
=
0.333
))
config
.
gpu_options
.
allow_growth
=
True
self
.
session
=
tf
.
Session
()
#def __del__(self):
# self.session.close()
\ No newline at end of file
bob/learn/tensorflow/utils/singleton.py
0 → 100644
View file @
db9747e1
# A singleton class decorator, based on http://stackoverflow.com/a/7346105/3301902
class
Singleton
(
object
):
"""
A non-thread-safe helper class to ease implementing singletons.
This should be used as a **decorator** -- not a metaclass -- to the class that should be a singleton.
The decorated class can define one `__init__` function that takes an arbitrary list of parameters.
To get the singleton instance, use the :py:meth:`instance` method. Trying to use `__call__` will result in a `TypeError` being raised.
Limitations:
* The decorated class cannot be inherited from.
* The documentation of the decorated class is replaced with the documentation of this class.
"""
def
__init__
(
self
,
decorated
):
self
.
_decorated
=
decorated
# see: functools.WRAPPER_ASSIGNMENTS:
self
.
__doc__
=
decorated
.
__doc__
self
.
__name__
=
decorated
.
__name__
self
.
__module__
=
decorated
.
__module__
self
.
__mro__
=
decorated
.
__mro__
self
.
__bases__
=
[]
self
.
_instance
=
None
def
create
(
self
,
*
args
,
**
kwargs
):
"""Creates the singleton instance, by passing the given parameters to the class' constructor."""
self
.
_instance
=
self
.
_decorated
(
*
args
,
**
kwargs
)
def
instance
(
self
):
"""Returns the singleton instance.
The function :py:meth:`create` must have been called before."""
if
self
.
_instance
is
None
:
self
.
create
()
return
self
.
_instance
def
__call__
(
self
):
raise
TypeError
(
'Singletons must be accessed through the `instance()` method.'
)
def
__instancecheck__
(
self
,
inst
):
return
isinstance
(
inst
,
self
.
_decorated
)
bob/learn/tensorflow/util.py
→
bob/learn/tensorflow/
utils/
util.py
View file @
db9747e1
File moved
doc/user_guide.rst
View file @
db9747e1
...
...
@@ -63,7 +63,7 @@ Now lets describe each step in detail.
Preparing your input data
-------------------------
.........................
In this library datasets are wrapped in **data shufflers**. Data shufflers are elements designed to shuffle
the input data for stochastic training.
...
...
@@ -73,24 +73,39 @@ It is possible to either use Memory (:py:class:`bob.learn.tensorflow.datashuffle
Disk (:py:class:`bob.learn.tensorflow.datashuffler.Disk`) data shufflers.
For the Memory data shufflers, as in the example, it is expected that the dataset is stored in `numpy.array`.
In the example that we provided the MNIST dataset was loaded and
reshaped to `[n, w, h, c]` where `n` is the size
of the batch, `w` and `h` are the image width and height and `c` is the
In the example that we provided the MNIST dataset was loaded and
reshaped to `[n, w, h, c]` where `n` is the size
of the batch, `w` and `h` are the image width and height and `c` is the
number of channels.
Creating the architecture
-------------------------
.........................
Architectures are assembled as a :py:class:`bob.learn.tensorflow.network.SequenceNetwork` object.
Once the objects are created it necessary to fill it up with
:py_api:
`Layers`_.
The library has already some crafted networks `Architectures`_
Once the objects are created it
is
necessary to fill it up with `Layers`_.
The library has already some crafted networks
implemented in
`Architectures`_
Defining a loss and training
----------------------------
............................
The loss function can be defined by any set of tensorflow operations.
In our example, we used the `tf.nn.sparse_softmax_cross_entropy_with_logits` loss, but we also have some crafted
loss functions for Siamese :py:class`bob.learn.tensorflow.loss.ContrastiveLoss` and Triplet networks :py:class`bob.learn.tensorflow.loss.TripletLoss`.
Predicting and computing the accuracy
-------------------------------------
The trainer is the real muscle here.
This element takes the inputs and trains the network.
As for the loss, we have specific trainers for Siamese (:py:class:`bob.learn.tensorflow.trainers.SiameseTrainer`) a
nd Triplet networks (:py:class:`bob.learn.tensorflow.trainers.TripletTrainer`).
Sandbox
-------
We have a sandbox of examples in a git repository `https://gitlab.idiap.ch/tiago.pereira/bob.learn.tensorflow_sandbox`_
The sandbox has some example of training:
- MNIST with softmax
- MNIST with Siamese Network
- MNIST with Triplet Network
- Face recognition with MOBIO database
- Face recognition with CASIA WebFace database
Write
Preview