Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
bob.learn.tensorflow
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
11
Issues
11
List
Boards
Labels
Milestones
Merge Requests
1
Merge Requests
1
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
bob
bob.learn.tensorflow
Commits
f7fac18d
Commit
f7fac18d
authored
Oct 30, 2016
by
Tiago de Freitas Pereira
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Preparing batch normalization
parent
5993034c
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
82 additions
and
25 deletions
+82
-25
bob/learn/tensorflow/layers/Conv2D.py
bob/learn/tensorflow/layers/Conv2D.py
+6
-1
bob/learn/tensorflow/layers/Dropout.py
bob/learn/tensorflow/layers/Dropout.py
+1
-1
bob/learn/tensorflow/layers/FullyConnected.py
bob/learn/tensorflow/layers/FullyConnected.py
+11
-6
bob/learn/tensorflow/layers/Layer.py
bob/learn/tensorflow/layers/Layer.py
+41
-1
bob/learn/tensorflow/layers/MaxPooling.py
bob/learn/tensorflow/layers/MaxPooling.py
+9
-3
bob/learn/tensorflow/network/__init__.py
bob/learn/tensorflow/network/__init__.py
+1
-0
bob/learn/tensorflow/script/train_mobio.py
bob/learn/tensorflow/script/train_mobio.py
+13
-10
bob/learn/tensorflow/trainers/Trainer.py
bob/learn/tensorflow/trainers/Trainer.py
+0
-3
No files found.
bob/learn/tensorflow/layers/Conv2D.py
View file @
f7fac18d
...
...
@@ -21,6 +21,7 @@ class Conv2D(Layer):
stride
=
[
1
,
1
,
1
,
1
],
weights_initialization
=
Xavier
(),
bias_initialization
=
Constant
(),
batch_norm
=
False
,
use_gpu
=
False
):
"""
...
...
@@ -39,6 +40,7 @@ class Conv2D(Layer):
activation
=
activation
,
weights_initialization
=
weights_initialization
,
bias_initialization
=
bias_initialization
,
batch_norm
=
batch_norm
,
use_gpu
=
use_gpu
,
...
...
@@ -69,11 +71,14 @@ class Conv2D(Layer):
scope
=
"b_"
+
str
(
self
.
name
)
)
def
get_graph
(
self
):
def
get_graph
(
self
,
training_phase
=
True
):
with
tf
.
name_scope
(
str
(
self
.
name
)):
conv2d
=
tf
.
nn
.
conv2d
(
self
.
input_layer
,
self
.
W
,
strides
=
self
.
stride
,
padding
=
'SAME'
)
if
self
.
batch_norm
:
conv2d
=
self
.
batch_normalize
(
conv2d
,
training_phase
)
if
self
.
activation
is
not
None
:
output
=
self
.
activation
(
tf
.
nn
.
bias_add
(
conv2d
,
self
.
b
))
else
:
...
...
bob/learn/tensorflow/layers/Dropout.py
View file @
f7fac18d
...
...
@@ -37,7 +37,7 @@ class Dropout(Layer):
self
.
input_layer
=
input_layer
return
def
get_graph
(
self
):
def
get_graph
(
self
,
training_phase
=
True
):
with
tf
.
name_scope
(
str
(
self
.
name
)):
output
=
tf
.
nn
.
dropout
(
self
.
input_layer
,
self
.
keep_prob
,
name
=
self
.
name
)
...
...
bob/learn/tensorflow/layers/FullyConnected.py
View file @
f7fac18d
...
...
@@ -21,6 +21,7 @@ class FullyConnected(Layer):
activation
=
None
,
weights_initialization
=
Xavier
(),
bias_initialization
=
Constant
(),
batch_norm
=
False
,
use_gpu
=
False
,
):
"""
...
...
@@ -35,10 +36,12 @@ class FullyConnected(Layer):
"""
super
(
FullyConnected
,
self
).
__init__
(
name
=
name
,
activation
=
activation
,
weights_initialization
=
weights_initialization
,
bias_initialization
=
bias_initialization
,
use_gpu
=
use_gpu
)
activation
=
activation
,
weights_initialization
=
weights_initialization
,
bias_initialization
=
bias_initialization
,
batch_norm
=
batch_norm
,
use_gpu
=
use_gpu
)
self
.
output_dim
=
output_dim
self
.
W
=
None
...
...
@@ -59,7 +62,7 @@ class FullyConnected(Layer):
scope
=
"b_"
+
str
(
self
.
name
)
)
def
get_graph
(
self
):
def
get_graph
(
self
,
training_phase
=
True
):
with
tf
.
name_scope
(
str
(
self
.
name
)):
...
...
@@ -67,10 +70,12 @@ class FullyConnected(Layer):
shape
=
self
.
input_layer
.
get_shape
().
as_list
()
#fc = tf.reshape(self.input_layer, [shape[0], shape[1] * shape[2] * shape[3]])
fc
=
tf
.
reshape
(
self
.
input_layer
,
[
-
1
,
shape
[
1
]
*
shape
[
2
]
*
shape
[
3
]])
else
:
fc
=
self
.
input_layer
if
self
.
batch_norm
:
fc
=
self
.
batch_normalize
(
fc
,
training_phase
)
if
self
.
activation
is
not
None
:
non_linear_fc
=
self
.
activation
(
tf
.
matmul
(
fc
,
self
.
W
)
+
self
.
b
)
output
=
non_linear_fc
...
...
bob/learn/tensorflow/layers/Layer.py
View file @
f7fac18d
...
...
@@ -18,6 +18,7 @@ class Layer(object):
activation
=
None
,
weights_initialization
=
Xavier
(),
bias_initialization
=
Constant
(),
batch_norm
=
False
,
use_gpu
=
False
):
"""
Base constructor
...
...
@@ -34,6 +35,7 @@ class Layer(object):
self
.
weights_initialization
=
weights_initialization
self
.
bias_initialization
=
bias_initialization
self
.
use_gpu
=
use_gpu
self
.
batch_norm
=
batch_norm
self
.
input_layer
=
None
self
.
activation
=
activation
...
...
@@ -41,5 +43,43 @@ class Layer(object):
def
create_variables
(
self
,
input_layer
):
NotImplementedError
(
"Please implement this function in derived classes"
)
def
get_graph
(
self
):
def
get_graph
(
self
,
training_phase
=
True
):
NotImplementedError
(
"Please implement this function in derived classes"
)
def
batch_normalize
(
self
,
x
,
phase_train
):
"""
Batch normalization on convolutional maps.
Args:
x: Tensor, 4D BHWD input maps
n_out: integer, depth of input maps
phase_train: boolean tf.Variable, true indicates training phase
scope: string, variable scope
affn: whether to affn-transform outputs
Return:
normed: batch-normalized maps
Ref: http://stackoverflow.com/questions/33949786/how-could-i-use-batch-normalization-in-tensorflow/33950177
"""
from
tensorflow.python.ops
import
control_flow_ops
name
=
'batch_norm'
with
tf
.
variable_scope
(
name
):
phase_train
=
tf
.
convert_to_tensor
(
phase_train
,
dtype
=
tf
.
bool
)
n_out
=
int
(
x
.
get_shape
()[
3
])
beta
=
tf
.
Variable
(
tf
.
constant
(
0.0
,
shape
=
[
n_out
],
dtype
=
x
.
dtype
),
name
=
name
+
'/beta'
,
trainable
=
True
,
dtype
=
x
.
dtype
)
gamma
=
tf
.
Variable
(
tf
.
constant
(
1.0
,
shape
=
[
n_out
],
dtype
=
x
.
dtype
),
name
=
name
+
'/gamma'
,
trainable
=
True
,
dtype
=
x
.
dtype
)
batch_mean
,
batch_var
=
tf
.
nn
.
moments
(
x
,
[
0
,
1
,
2
],
name
=
'moments'
)
ema
=
tf
.
train
.
ExponentialMovingAverage
(
decay
=
0.9
)
def
mean_var_with_update
():
ema_apply_op
=
ema
.
apply
([
batch_mean
,
batch_var
])
with
tf
.
control_dependencies
([
ema_apply_op
]):
return
tf
.
identity
(
batch_mean
),
tf
.
identity
(
batch_var
)
mean
,
var
=
control_flow_ops
.
cond
(
phase_train
,
mean_var_with_update
,
lambda
:
(
ema
.
average
(
batch_mean
),
ema
.
average
(
batch_var
)))
normed
=
tf
.
nn
.
batch_normalization
(
x
,
mean
,
var
,
beta
,
gamma
,
1e-3
)
return
normed
bob/learn/tensorflow/layers/MaxPooling.py
View file @
f7fac18d
...
...
@@ -10,11 +10,14 @@ from .Layer import Layer
class
MaxPooling
(
Layer
):
def
__init__
(
self
,
name
,
shape
=
[
1
,
2
,
2
,
1
],
strides
=
[
1
,
1
,
1
,
1
],
activation
=
None
):
def
__init__
(
self
,
name
,
shape
=
[
1
,
2
,
2
,
1
],
strides
=
[
1
,
1
,
1
,
1
],
batch_norm
=
False
,
activation
=
None
):
"""
Constructor
"""
super
(
MaxPooling
,
self
).
__init__
(
name
,
use_gpu
=
False
,
activation
=
activation
)
super
(
MaxPooling
,
self
).
__init__
(
name
,
use_gpu
=
False
,
activation
=
activation
,
batch_norm
=
batch_norm
)
self
.
shape
=
shape
self
.
strides
=
strides
...
...
@@ -22,10 +25,13 @@ class MaxPooling(Layer):
self
.
input_layer
=
input_layer
return
def
get_graph
(
self
):
def
get_graph
(
self
,
training_phase
=
True
):
with
tf
.
name_scope
(
str
(
self
.
name
)):
output
=
tf
.
nn
.
max_pool
(
self
.
input_layer
,
ksize
=
self
.
shape
,
strides
=
self
.
strides
,
padding
=
'SAME'
)
if
self
.
batch_norm
:
output
=
self
.
batch_normalize
(
output
,
training_phase
)
if
self
.
activation
is
not
None
:
output
=
self
.
activation
(
output
)
...
...
bob/learn/tensorflow/network/__init__.py
View file @
f7fac18d
...
...
@@ -11,6 +11,7 @@ from .LenetDropout import LenetDropout
from
.MLP
import
MLP
from
.FaceNet
import
FaceNet
from
.FaceNetSimple
import
FaceNetSimple
from
.VGG16
import
VGG16
# gets sphinx autodoc done right - don't remove it
__all__
=
[
_
for
_
in
dir
()
if
not
_
.
startswith
(
'_'
)]
bob/learn/tensorflow/script/train_mobio.py
View file @
f7fac18d
...
...
@@ -23,11 +23,15 @@ import tensorflow as tf
from
..
import
util
SEED
=
10
from
bob.learn.tensorflow.datashuffler
import
TripletWithSelectionDisk
,
TripletDisk
,
TripletWithFastSelectionDisk
from
bob.learn.tensorflow.network
import
Lenet
,
MLP
,
LenetDropout
,
VGG
,
Chopra
,
Dummy
,
FaceNet
,
FaceNetSimple
from
bob.learn.tensorflow.trainers
import
SiameseTrainer
,
Trainer
,
TripletTrainer
from
bob.learn.tensorflow.network
import
Lenet
,
MLP
,
LenetDropout
,
VGG
,
Chopra
,
Dummy
,
FaceNet
,
FaceNetSimple
,
VGG16
from
bob.learn.tensorflow.trainers
import
SiameseTrainer
,
Trainer
,
TripletTrainer
,
constant
from
bob.learn.tensorflow.loss
import
ContrastiveLoss
,
BaseLoss
,
TripletLoss
import
numpy
import
os
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
=
"0,1,2,3"
#os.environ["CUDA_VISIBLE_DEVICES"] = ""
def
main
():
args
=
docopt
(
__doc__
,
version
=
'Mnist training with TensorFlow'
)
...
...
@@ -41,7 +45,8 @@ def main():
import
bob.db.mobio
db_mobio
=
bob
.
db
.
mobio
.
Database
()
directory
=
"/idiap/temp/tpereira/DEEP_FACE/CASIA_WEBFACE/mobio/preprocessed/"
#directory = "/idiap/temp/tpereira/DEEP_FACE/CASIA_WEBFACE/mobio/preprocessed/"
directory
=
"./preprocessed/"
# Preparing train set
#train_objects = db_mobio.objects(protocol="male", groups="world")
...
...
@@ -78,9 +83,7 @@ def main():
# Preparing the architecture
#architecture = Chopra(seed=SEED, fc1_output=n_classes)
#architecture = FaceNet(seed=SEED, use_gpu=USE_GPU)
architecture
=
FaceNetSimple
(
seed
=
SEED
,
use_gpu
=
USE_GPU
)
#optimizer = tf.train.GradientDescentOptimizer(0.0005)
architecture
=
VGG16
(
seed
=
SEED
,
use_gpu
=
USE_GPU
)
#loss = BaseLoss(tf.nn.sparse_softmax_cross_entropy_with_logits, tf.reduce_mean)
...
...
@@ -98,14 +101,14 @@ def main():
# temp_dir="./LOGS_MOBIO/siamese-cnn-prefetch")
loss
=
TripletLoss
(
margin
=
0.2
)
#optimizer = tf.train.GradientDescentOptimizer(0.000000000001)
#optimizer = optimizer,
optimizer
=
tf
.
train
.
GradientDescentOptimizer
(
0.05
)
trainer
=
TripletTrainer
(
architecture
=
architecture
,
loss
=
loss
,
iterations
=
ITERATIONS
,
base_learning_rate
=
0.05
,
learning_rate
=
constant
(
0.05
)
,
prefetch
=
False
,
optimizer
=
optimizer
,
snapshot
=
200
,
temp_dir
=
"
/idiap/temp/tpereira/CNN_MODELS/triplet-cnn-all-mobio
"
)
temp_dir
=
"
./logs/
"
)
#trainer.train(train_data_shuffler, validation_data_shuffler)
trainer
.
train
(
train_data_shuffler
)
bob/learn/tensorflow/trainers/Trainer.py
View file @
f7fac18d
...
...
@@ -15,9 +15,6 @@ import time
from
bob.learn.tensorflow.datashuffler.OnlineSampling
import
OnLineSampling
from
.learning_rate
import
constant
#os.environ["CUDA_VISIBLE_DEVICES"] = "1,3,0,2"
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
=
""
logger
=
bob
.
core
.
log
.
setup
(
"bob.learn.tensorflow"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment