bob / bob.learn.tensorflow · Commits · 9e3b0c00

Commit 9e3b0c00, authored Oct 27, 2020 by Amir MOHAMMADI

Add a callback that backs up and restores other callbacks

parent 58427600 · Pipeline #44928 failed in 2 minutes and 29 seconds
Changes 2 · Pipelines 1
bob/learn/tensorflow/callbacks.py (new file, mode 0 → 100644)
import json
import os

import tensorflow as tf
from tensorflow.keras import callbacks


class CustomBackupAndRestore(tf.keras.callbacks.experimental.BackupAndRestore):
    """This callback is experimental and might be removed in future.
    See :any:`add_backup_callback`
    """

    def __init__(self, callbacks, backup_dir, **kwargs):
        super().__init__(backup_dir=backup_dir, **kwargs)
        self.callbacks = callbacks
        self.callbacks_backup_path = os.path.join(self.backup_dir, "callbacks.json")

    def backup(self):
        # save only the scalar (int/float) attributes of each callback,
        # since those are what JSON can serialize safely
        variables = {}
        for cb_name, cb in self.callbacks.items():
            variables[cb_name] = {}
            for k, v in cb.__dict__.items():
                if not isinstance(v, (int, float)):
                    continue
                variables[cb_name][k] = v
        with open(self.callbacks_backup_path, "w") as f:
            json.dump(variables, f, indent=4, sort_keys=True)

    def restore(self):
        if not os.path.isfile(self.callbacks_backup_path):
            return False
        with open(self.callbacks_backup_path, "r") as f:
            variables = json.load(f)
        # put the saved attributes back into each callback
        for cb_name, cb in self.callbacks.items():
            if cb_name not in variables:
                continue
            for k, v in cb.__dict__.items():
                if k in variables[cb_name]:
                    cb.__dict__[k] = variables[cb_name][k]
        return True

    def on_train_begin(self, logs=None):
        super().on_train_begin(logs=logs)
        if self.restore():
            print(f"Restored callbacks from {self.callbacks_backup_path}")
        else:
            print("Did not restore callbacks")

    def on_epoch_end(self, epoch, logs=None):
        super().on_epoch_end(epoch, logs=logs)
        self.backup()

    def on_train_end(self, logs=None):
        # do not delete backups
        pass


def add_backup_callback(callbacks, backup_dir, **kwargs):
    """Adds a backup callback to your callbacks to restore the training process
    if it is interrupted.

    .. warning::

        This function is experimental and may be removed or changed in future.

    Examples
    --------

    >>> CHECKPOINT = "checkpoints"
    >>> callbacks = {
    ...     "best": tf.keras.callbacks.ModelCheckpoint(
    ...         f"{CHECKPOINT}/best",
    ...         monitor="val_acc",
    ...         save_best_only=True,
    ...         mode="max",
    ...         verbose=1,
    ...     ),
    ...     "tensorboard": tf.keras.callbacks.TensorBoard(
    ...         log_dir=f"{CHECKPOINT}/logs",
    ...         update_freq=15,
    ...         write_graph=False,
    ...     ),
    ... }
    >>> callbacks = add_backup_callback(callbacks, f"{CHECKPOINT}/backup")
    >>> # callbacks will be a list that can be given to model.fit
    >>> isinstance(callbacks, list)
    True
    """
    if not isinstance(callbacks, dict):
        raise ValueError(
            "Please provide a dictionary of callbacks where "
            "keys are simple names for your callbacks!"
        )
    cb = CustomBackupAndRestore(callbacks=callbacks, backup_dir=backup_dir, **kwargs)
    callbacks = list(callbacks.values())
    callbacks.append(cb)
    return callbacks
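
For reference, a minimal usage sketch of the new module (not part of the commit itself); the toy model, data, and checkpoint paths below are placeholders:

import tensorflow as tf

from bob.learn.tensorflow.callbacks import add_backup_callback

# placeholder model and data, just to exercise the callback
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer="adam", loss="mse")
x, y = tf.random.normal((64, 4)), tf.random.normal((64, 1))

callbacks = {
    "latest": tf.keras.callbacks.ModelCheckpoint("checkpoints/latest", verbose=1),
    "nan": tf.keras.callbacks.TerminateOnNaN(),
}
# returns a plain list: the original callbacks plus the backup callback
callbacks = add_backup_callback(callbacks, backup_dir="checkpoints/backup")

# if this run is interrupted and restarted, BackupAndRestore resumes the
# model/epoch state, and the int/float attributes of the named callbacks are
# restored from checkpoints/backup/callbacks.json
model.fit(x, y, epochs=5, callbacks=callbacks)

Each callback's scalar attributes are rewritten to backup_dir/callbacks.json at every epoch end, so a restarted run resumes its callbacks' bookkeeping instead of starting it from scratch.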
examples/MSCeleba_centerloss_mixed_precision_multi_worker.py
 #!/usr/bin/env python
 # coding: utf-8
 import os
-import pickle

...
@@ -8,12 +7,14 @@ from multiprocessing import cpu_count

 import pkg_resources
 import tensorflow as tf
-from bob.learn.tensorflow.losses import CenterLoss, CenterLossLayer
+from bob.extension import rc
+from bob.learn.tensorflow.callbacks import add_backup_callback
+from bob.learn.tensorflow.losses import CenterLoss
+from bob.learn.tensorflow.losses import CenterLossLayer
 from bob.learn.tensorflow.models.inception_resnet_v2 import InceptionResNetV2
 from bob.learn.tensorflow.utils import predict_using_tensors
 from tensorflow.keras import layers
 from tensorflow.keras.mixed_precision import experimental as mixed_precision
-from bob.extension import rc

 policy = mixed_precision.Policy("mixed_float16")
 mixed_precision.set_policy(policy)

...
@@ -29,9 +30,7 @@ VALIDATION_TF_RECORD_PATHS = (
 # there are 2812 samples in the validation set
 VALIDATION_SAMPLES = 2812

-CHECKPOINT = (
-    f"{rc['temp']}/models/inception_v2_batchnorm_rgb_msceleba_mixed_precision"
-)
+CHECKPOINT = f"{rc['temp']}/models/inception_v2_batchnorm_rgb_msceleba_mixed_precision"

 AUTOTUNE = tf.data.experimental.AUTOTUNE
 TFRECORD_PARALLEL_READ = cpu_count()

...
@@ -113,7 +112,8 @@ def prepare_dataset(tf_record_paths, batch_size, shuffle=False, augment=False):
         ds = ds.shuffle(SHUFFLE_BUFFER).repeat(EPOCHS)
     preprocessor = get_preprocessor()
     ds = ds.batch(batch_size).map(
-        partial(preprocess, preprocessor, augment=augment), num_parallel_calls=AUTOTUNE,
+        partial(preprocess, preprocessor, augment=augment),
+        num_parallel_calls=AUTOTUNE,
     )
     # Use buffered prefetching on all datasets

...
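
Aside (not part of the diff): Dataset.map passes only the dataset elements to its function, so the script uses functools.partial to bind the extra arguments first. A self-contained sketch with a hypothetical preprocess function:

from functools import partial

import tensorflow as tf

def preprocess(preprocessor, image, augment=False):
    # hypothetical stand-in mirroring the call site above
    return preprocessor(image)

ds = tf.data.Dataset.from_tensor_slices(tf.zeros((4, 8, 8, 3)))
ds = ds.batch(2).map(
    # bind preprocessor and augment now; map supplies each batch as `image`
    partial(preprocess, lambda img: img / 255.0, augment=False),
    num_parallel_calls=tf.data.experimental.AUTOTUNE,
)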
@@ -247,33 +247,6 @@ def build_and_compile_model(global_batch_size):
     return model


-class CustomBackupAndRestore(tf.keras.callbacks.experimental.BackupAndRestore):
-    def __inti__(self, custom_objects, **kwargs):
-        super().__inti__(**kwargs)
-        self.custom_objects = custom_objects
-        self.custom_objects_path = os.path.join(self.backup_dir, "custom_objects.pkl")
-
-    def on_epoch_end(self, epoch, logs=None):
-        super().on_epoch_end(epoch, logs=logs)
-        # pickle custom objects
-        with open(self.custom_objects_path, "wb") as f:
-            pickle.dump(self.custom_objects, f)
-
-    def on_train_begin(self, logs=None):
-        super().on_train_begin(logs=logs)
-        if not os.path.exists(self.custom_objects_path):
-            return
-        # load custom objects
-        with open(self.custom_objects_path, "rb") as f:
-            self.custom_objects = pickle.load(f)
-
-    def on_train_end(self, logs=None):
-        # do not delete backups
-        pass
-
-
 def train_and_evaluate(tf_config):
     os.environ["TF_CONFIG"] = json.dumps(tf_config)

...
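
Aside (not part of the diff): the tf_config argument is the standard TF_CONFIG cluster spec that TensorFlow's multi-worker strategies read from the environment. A minimal sketch with placeholder hostnames and ports:

import json
import os

# hypothetical two-worker cluster spec; hostnames/ports are placeholders
tf_config = {
    "cluster": {"worker": ["worker0.example.com:2222", "worker1.example.com:2222"]},
    "task": {"type": "worker", "index": 0},  # this process is worker 0
}
os.environ["TF_CONFIG"] = json.dumps(tf_config)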
@@ -312,25 +285,27 @@ def train_and_evaluate(tf_config):
         else:
             return 0.001

-    callbacks = [
-        tf.keras.callbacks.ModelCheckpoint(f"{CHECKPOINT}/latest", verbose=1),
-        tf.keras.callbacks.ModelCheckpoint(
+    callbacks = {
+        "latest": tf.keras.callbacks.ModelCheckpoint(f"{CHECKPOINT}/latest", verbose=1),
+        "best": tf.keras.callbacks.ModelCheckpoint(
             f"{CHECKPOINT}/best",
             monitor=val_metric_name,
             save_best_only=True,
             mode="max",
             verbose=1,
         ),
-        tf.keras.callbacks.TensorBoard(
+        "tensorboard": tf.keras.callbacks.TensorBoard(
             log_dir=f"{CHECKPOINT}/logs", update_freq=15, profile_batch="10,50"
         ),
-        tf.keras.callbacks.LearningRateScheduler(scheduler, verbose=1),
-        # tf.keras.callbacks.ReduceLROnPlateau(
+        "lr": tf.keras.callbacks.LearningRateScheduler(scheduler, verbose=1),
+        # "lr": tf.keras.callbacks.ReduceLROnPlateau(
         #     monitor=val_metric_name, factor=0.2, patience=5, min_lr=0.001
         # ),
-        tf.keras.callbacks.TerminateOnNaN(),
-    ]
-    callbacks.append(
-        CustomBackupAndRestore(backup_dir=f"{CHECKPOINT}/backup", custom_objects=callbacks)
-    )
+        "nan": tf.keras.callbacks.TerminateOnNaN(),
+    }
+    callbacks = add_backup_callback(
+        callbacks=callbacks, backup_dir=f"{CHECKPOINT}/backup"
+    )

     model.fit(
         train_ds,

...