Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
bob
bob.learn.tensorflow
Commits
35180690
Commit
35180690
authored
Oct 05, 2017
by
Tiago de Freitas Pereira
Browse files
Implemented center loss
parent
a3e44720
Changes
3
Hide whitespace changes
Inline
Side-by-side
bob/learn/tensorflow/loss/BaseLoss.py
View file @
35180690
...
...
@@ -61,7 +61,7 @@ class MeanSoftMaxLossCenterLoss(object):
Mean softmax loss. Basically it wraps the function tf.nn.sparse_softmax_cross_entropy_with_logits.
"""
def
__init__
(
self
,
name
=
"loss"
,
add_regularization_losses
=
True
,
alpha
=
0.9
,
factor
=
0.01
,
n_classes
=
10
):
def
__init__
(
self
,
name
=
"loss"
,
alpha
=
0.9
,
factor
=
0.01
,
n_classes
=
10
):
"""
Constructor
...
...
@@ -73,46 +73,36 @@ class MeanSoftMaxLossCenterLoss(object):
"""
self
.
name
=
name
self
.
add_regularization_losses
=
add_regularization_losses
self
.
n_classes
=
n_classes
self
.
alpha
=
alpha
self
.
factor
=
factor
def
append_center_loss
(
self
,
features
,
label
):
nrof_features
=
features
.
get_shape
()[
1
]
centers
=
tf
.
get_variable
(
'centers'
,
[
self
.
n_classes
,
nrof_features
],
dtype
=
tf
.
float32
,
initializer
=
tf
.
constant_initializer
(
0
),
trainable
=
False
)
label
=
tf
.
reshape
(
label
,
[
-
1
])
centers_batch
=
tf
.
gather
(
centers
,
label
)
diff
=
(
1
-
self
.
alpha
)
*
(
centers_batch
-
features
)
centers
=
tf
.
scatter_sub
(
centers
,
label
,
diff
)
loss
=
tf
.
reduce_mean
(
tf
.
square
(
features
-
centers_batch
))
return
loss
def
__call__
(
self
,
logits_prelogits
,
label
):
#TODO: Test the dictionary
logits
=
logits_prelogits
[
'logits'
]
def
__call__
(
self
,
logits
,
prelogits
,
label
):
# Cross entropy
loss
=
tf
.
reduce_mean
(
tf
.
nn
.
sparse_softmax_cross_entropy_with_logits
(
logits
=
logits
,
labels
=
label
),
name
=
self
.
name
)
with
tf
.
variable_scope
(
'cross_entropy_loss'
):
loss
=
tf
.
reduce_mean
(
tf
.
nn
.
sparse_softmax_cross_entropy_with_logits
(
logits
=
logits
,
labels
=
label
),
name
=
self
.
name
)
# Appending center loss
prelogits
=
logits_prelogits
[
'prelogits'
]
center_loss
=
self
.
append_center_loss
(
prelogits
,
label
)
tf
.
add_to_collection
(
tf
.
GraphKeys
.
REGULARIZATION_LOSSES
,
center_loss
*
self
.
factor
)
# Appending center loss
with
tf
.
variable_scope
(
'center_loss'
):
n_features
=
prelogits
.
get_shape
()[
1
]
centers
=
tf
.
get_variable
(
'centers'
,
[
self
.
n_classes
,
n_features
],
dtype
=
tf
.
float32
,
initializer
=
tf
.
constant_initializer
(
0
),
trainable
=
False
)
label
=
tf
.
reshape
(
label
,
[
-
1
])
centers_batch
=
tf
.
gather
(
centers
,
label
)
diff
=
(
1
-
self
.
alpha
)
*
(
centers_batch
-
prelogits
)
centers
=
tf
.
scatter_sub
(
centers
,
label
,
diff
)
center_loss
=
tf
.
reduce_mean
(
tf
.
square
(
prelogits
-
centers_batch
))
tf
.
add_to_collection
(
tf
.
GraphKeys
.
REGULARIZATION_LOSSES
,
center_loss
*
self
.
factor
)
# Adding the regularizers in the loss
if
self
.
add_regularization
_loss
es
:
with
tf
.
variable_scope
(
'total
_loss
'
)
:
regularization_losses
=
tf
.
get_collection
(
tf
.
GraphKeys
.
REGULARIZATION_LOSSES
)
loss
=
tf
.
add_n
([
loss
]
+
regularization_losses
,
name
=
'total_loss'
)
total_
loss
=
tf
.
add_n
([
loss
]
+
regularization_losses
,
name
=
'total_loss'
)
return
loss
return
total_loss
,
centers
bob/learn/tensorflow/trainers/SiameseTrainer.py
View file @
35180690
...
...
@@ -219,7 +219,6 @@ class SiameseTrainer(Trainer):
return
feed_dict
def
fit
(
self
,
step
):
feed_dict
=
self
.
get_feed_dict
(
self
.
train_data_shuffler
)
_
,
l
,
bt_class
,
wt_class
,
lr
,
summary
=
self
.
session
.
run
([
self
.
optimizer
,
...
...
bob/learn/tensorflow/trainers/Trainer.py
View file @
35180690
...
...
@@ -177,7 +177,7 @@ class Trainer(object):
self
.
compute_validation
(
step
)
# Taking snapshot
if
step
%
self
.
snapshot
==
0
:
if
step
%
self
.
snapshot
==
0
:
logger
.
info
(
"Taking snapshot"
)
path
=
os
.
path
.
join
(
self
.
temp_dir
,
'model_snapshot{0}.ckp'
.
format
(
step
))
self
.
saver
.
save
(
self
.
session
,
path
,
global_step
=
step
)
...
...
@@ -214,6 +214,7 @@ class Trainer(object):
# Learning rate
learning_rate
=
None
,
prelogits
=
None
):
"""
...
...
@@ -229,7 +230,6 @@ class Trainer(object):
learning_rate: Learning rate
"""
# Getting the pointer to the placeholders
self
.
data_ph
=
self
.
train_data_shuffler
(
"data"
,
from_queue
=
True
)
self
.
label_ph
=
self
.
train_data_shuffler
(
"label"
,
from_queue
=
True
)
...
...
@@ -237,8 +237,13 @@ class Trainer(object):
self
.
graph
=
graph
self
.
loss
=
loss
# Attaching the loss in the graph
self
.
predictor
=
self
.
loss
(
self
.
graph
,
self
.
label_ph
)
# TODO: SPECIFIC HACK FOR THE CENTER LOSS. I NEED TO FIND A CLEAN SOLUTION FOR THAT
self
.
centers
=
None
if
prelogits
is
not
None
:
tf
.
add_to_collection
(
"prelogits"
,
prelogits
)
self
.
predictor
,
self
.
centers
=
self
.
loss
(
self
.
graph
,
prelogits
,
self
.
label_ph
)
else
:
self
.
predictor
=
self
.
loss
(
self
.
graph
,
self
.
label_ph
)
self
.
optimizer_class
=
optimizer
self
.
learning_rate
=
learning_rate
...
...
@@ -257,11 +262,8 @@ class Trainer(object):
# Saving some variables
tf
.
add_to_collection
(
"global_step"
,
self
.
global_step
)
if
isinstance
(
self
.
graph
,
dict
):
tf
.
add_to_collection
(
"graph"
,
self
.
graph
[
'logits'
])
tf
.
add_to_collection
(
"prelogits"
,
self
.
graph
[
'prelogits'
])
else
:
tf
.
add_to_collection
(
"graph"
,
self
.
graph
)
tf
.
add_to_collection
(
"graph"
,
self
.
graph
)
tf
.
add_to_collection
(
"predictor"
,
self
.
predictor
)
...
...
@@ -273,6 +275,10 @@ class Trainer(object):
tf
.
add_to_collection
(
"summaries_train"
,
self
.
summaries_train
)
# Appending histograms for each trainable variable
for
var
in
tf
.
trainable_variables
():
tf
.
summary
.
histogram
(
var
.
op
.
name
,
var
)
# Same business with the validation
if
self
.
validation_data_shuffler
is
not
None
:
self
.
validation_data_ph
=
self
.
validation_data_shuffler
(
"data"
,
from_queue
=
True
)
...
...
@@ -280,9 +286,9 @@ class Trainer(object):
self
.
validation_graph
=
validation_graph
if
self
.
validate_with_embeddings
:
if
self
.
validate_with_embeddings
:
self
.
validation_predictor
=
self
.
validation_graph
else
:
else
:
self
.
validation_predictor
=
self
.
loss
(
self
.
validation_graph
,
self
.
validation_label_ph
)
self
.
summaries_validation
=
self
.
create_general_summary
(
self
.
validation_predictor
,
self
.
validation_graph
,
self
.
validation_label_ph
)
...
...
@@ -318,13 +324,13 @@ class Trainer(object):
self
.
saver
=
tf
.
train
.
import_meta_graph
(
file_name
,
clear_devices
=
clear_devices
)
self
.
saver
.
restore
(
self
.
session
,
tf
.
train
.
latest_checkpoint
(
os
.
path
.
dirname
(
file_name
)))
def
load_variables_from_external_model
(
self
,
file_name
,
var_list
):
def
load_variables_from_external_model
(
self
,
checkpoint_path
,
var_list
):
"""
Load a set of variables from a given model and update them in the current one
** Parameters **
file_name
:
checkpoint_path
:
Name of the tensorflow model to be loaded
var_list:
List of variables to be loaded. A tensorflow exception will be raised in case the variable does not exist
...
...
@@ -338,7 +344,7 @@ class Trainer(object):
tf_varlist
+=
tf
.
get_collection
(
tf
.
GraphKeys
.
GLOBAL_VARIABLES
,
scope
=
v
)
saver
=
tf
.
train
.
Saver
(
tf_varlist
)
saver
.
restore
(
self
.
session
,
file_name
)
saver
.
restore
(
self
.
session
,
tf
.
train
.
latest_checkpoint
(
checkpoint_path
)
)
def
create_network_from_file
(
self
,
file_name
,
clear_devices
=
True
):
"""
...
...
@@ -406,8 +412,14 @@ class Trainer(object):
"""
if
self
.
train_data_shuffler
.
prefetch
:
_
,
l
,
lr
,
summary
=
self
.
session
.
run
([
self
.
optimizer
,
self
.
predictor
,
self
.
learning_rate
,
self
.
summaries_train
])
# TODO: SPECIFIC HACK FOR THE CENTER LOSS. I NEED TO FIND A CLEAN SOLUTION FOR THAT
if
self
.
centers
is
None
:
_
,
l
,
lr
,
summary
=
self
.
session
.
run
([
self
.
optimizer
,
self
.
predictor
,
self
.
learning_rate
,
self
.
summaries_train
])
else
:
_
,
l
,
lr
,
summary
,
_
=
self
.
session
.
run
([
self
.
optimizer
,
self
.
predictor
,
self
.
learning_rate
,
self
.
summaries_train
,
self
.
centers
])
else
:
feed_dict
=
self
.
get_feed_dict
(
self
.
train_data_shuffler
)
_
,
l
,
lr
,
summary
=
self
.
session
.
run
([
self
.
optimizer
,
self
.
predictor
,
...
...
@@ -473,10 +485,7 @@ class Trainer(object):
tf
.
summary
.
scalar
(
'lr'
,
self
.
learning_rate
)
# Computing accuracy
if
isinstance
(
output
,
dict
):
correct_prediction
=
tf
.
equal
(
tf
.
argmax
(
output
[
'logits'
],
1
),
label
)
else
:
correct_prediction
=
tf
.
equal
(
tf
.
argmax
(
output
,
1
),
label
)
correct_prediction
=
tf
.
equal
(
tf
.
argmax
(
output
,
1
),
label
)
accuracy
=
tf
.
reduce_mean
(
tf
.
cast
(
correct_prediction
,
tf
.
float32
))
tf
.
summary
.
scalar
(
'accuracy'
,
accuracy
)
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment