Skip to content
Snippets Groups Projects
Commit 400c5d99 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

Add a callback that backups and restores other callbacks

parent cf5783f4
Branches
No related tags found
1 merge request!87WIP: Updates
Pipeline #44584 failed
import json
import os
import tensorflow as tf
from tensorflow.keras import callbacks
class CustomBackupAndRestore(tf.keras.callbacks.experimental.BackupAndRestore):
"""This callback is experimental and might be removed in future.
See :any:`add_backup_callback`
"""
def __init__(self, callbacks, backup_dir, **kwargs):
super().__init__(backup_dir=backup_dir, **kwargs)
self.callbacks = callbacks
self.callbacks_backup_path = os.path.join(self.backup_dir, "callbacks.json")
def backup(self):
variables = {}
for cb_name, cb in self.callbacks.items():
variables[cb_name] = {}
for k, v in cb.__dict__.items():
if not isinstance(v, (int, float)):
continue
variables[cb_name][k] = v
with open(self.callbacks_backup_path, "w") as f:
json.dump(variables, f, indent=4, sort_keys=True)
def restore(self):
if not os.path.isfile(self.callbacks_backup_path):
return False
with open(self.callbacks_backup_path, "r") as f:
variables = json.load(f)
for cb_name, cb in self.callbacks.items():
if cb_name not in variables:
continue
for k, v in cb.__dict__.items():
if k in variables[cb_name]:
cb.__dict__[k] = variables[cb_name][k]
return True
def on_train_begin(self, logs=None):
super().on_train_begin(logs=logs)
if self.restore():
print(f"Restored callbacks from {self.callbacks_backup_path}")
else:
print("Did not restore callbacks")
def on_epoch_end(self, epoch, logs=None):
super().on_epoch_end(epoch, logs=logs)
self.backup()
def on_train_end(self, logs=None):
# do not delete backups
pass
def add_backup_callback(callbacks, backup_dir, **kwargs):
"""Adds a backup callback to your callbacks to restore the training process
if it is interrupted.
.. warning::
This function is experimental and may be removed or changed in future.
Examples
--------
>>> CHECKPOINT = "checkpoints"
>>> callbacks = {
... "best": tf.keras.callbacks.ModelCheckpoint(
... f"{CHECKPOINT}/best",
... monitor="val_acc",
... save_best_only=True,
... mode="max",
... verbose=1,
... ),
... "tensorboard": tf.keras.callbacks.TensorBoard(
... log_dir=f"{CHECKPOINT}/logs",
... update_freq=15,
... write_graph=False,
... ),
... }
>>> callbacks = add_backup_callback(callbacks, f"{CHECKPOINT}/backup")
>>> # callbacks will be a list that can be given to model.fit
>>> isinstance(callbacks, list)
True
"""
if not isinstance(callbacks, dict):
raise ValueError(
"Please provide a dictionary of callbacks where "
"keys are simple names for your callbacks!"
)
cb = CustomBackupAndRestore(callbacks=callbacks, backup_dir=backup_dir, **kwargs)
callbacks = list(callbacks.values())
callbacks.append(cb)
return callbacks
#!/usr/bin/env python
# coding: utf-8
import os
import pickle
......@@ -8,12 +7,14 @@ from multiprocessing import cpu_count
import pkg_resources
import tensorflow as tf
from bob.learn.tensorflow.losses import CenterLoss, CenterLossLayer
from bob.extension import rc
from bob.learn.tensorflow.callbacks import add_backup_callback
from bob.learn.tensorflow.losses import CenterLoss
from bob.learn.tensorflow.losses import CenterLossLayer
from bob.learn.tensorflow.models.inception_resnet_v2 import InceptionResNetV2
from bob.learn.tensorflow.utils import predict_using_tensors
from tensorflow.keras import layers
from tensorflow.keras.mixed_precision import experimental as mixed_precision
from bob.extension import rc
policy = mixed_precision.Policy("mixed_float16")
mixed_precision.set_policy(policy)
......@@ -29,9 +30,7 @@ VALIDATION_TF_RECORD_PATHS = (
# there are 2812 samples in the validation set
VALIDATION_SAMPLES = 2812
CHECKPOINT = (
f"{rc['temp']}/models/inception_v2_batchnorm_rgb_msceleba_mixed_precision"
)
CHECKPOINT = f"{rc['temp']}/models/inception_v2_batchnorm_rgb_msceleba_mixed_precision"
AUTOTUNE = tf.data.experimental.AUTOTUNE
TFRECORD_PARALLEL_READ = cpu_count()
......@@ -113,7 +112,8 @@ def prepare_dataset(tf_record_paths, batch_size, shuffle=False, augment=False):
ds = ds.shuffle(SHUFFLE_BUFFER).repeat(EPOCHS)
preprocessor = get_preprocessor()
ds = ds.batch(batch_size).map(
partial(preprocess, preprocessor, augment=augment), num_parallel_calls=AUTOTUNE,
partial(preprocess, preprocessor, augment=augment),
num_parallel_calls=AUTOTUNE,
)
# Use buffered prefecting on all datasets
......@@ -247,33 +247,6 @@ def build_and_compile_model(global_batch_size):
return model
class CustomBackupAndRestore(tf.keras.callbacks.experimental.BackupAndRestore):
def __inti__(self, custom_objects, **kwargs):
super().__inti__(**kwargs)
self.custom_objects = custom_objects
self.custom_objects_path = os.path.join(self.backup_dir, "custom_objects.pkl")
def on_epoch_end(self, epoch, logs=None):
super().on_epoch_end(epoch, logs=logs)
# pickle custom objects
with open(self.custom_objects_path, "wb") as f:
pickle.dump(self.custom_objects, f)
def on_train_begin(self, logs=None):
super().on_train_begin(logs=logs)
if not os.path.exists(self.custom_objects_path):
return
# load custom objects
with open(self.custom_objects_path, "rb") as f:
self.custom_objects = pickle.load(f)
def on_train_end(self, logs=None):
# do not delete backups
pass
def train_and_evaluate(tf_config):
os.environ["TF_CONFIG"] = json.dumps(tf_config)
......@@ -312,25 +285,27 @@ def train_and_evaluate(tf_config):
else:
return 0.001
callbacks = [
tf.keras.callbacks.ModelCheckpoint(f"{CHECKPOINT}/latest", verbose=1),
tf.keras.callbacks.ModelCheckpoint(
callbacks = {
"latest": tf.keras.callbacks.ModelCheckpoint(f"{CHECKPOINT}/latest", verbose=1),
"best": tf.keras.callbacks.ModelCheckpoint(
f"{CHECKPOINT}/best",
monitor=val_metric_name,
save_best_only=True,
mode="max",
verbose=1,
),
tf.keras.callbacks.TensorBoard(
"tensorboard": tf.keras.callbacks.TensorBoard(
log_dir=f"{CHECKPOINT}/logs", update_freq=15, profile_batch="10,50"
),
tf.keras.callbacks.LearningRateScheduler(scheduler, verbose=1),
# tf.keras.callbacks.ReduceLROnPlateau(
"lr": tf.keras.callbacks.LearningRateScheduler(scheduler, verbose=1),
# "lr": tf.keras.callbacks.ReduceLROnPlateau(
# monitor=val_metric_name, factor=0.2, patience=5, min_lr=0.001
# ),
tf.keras.callbacks.TerminateOnNaN(),
]
callbacks.append(CustomBackupAndRestore(backup_dir=f"{CHECKPOINT}/backup", custom_objects=callbacks))
"nan": tf.keras.callbacks.TerminateOnNaN(),
}
callbacks = add_backup_callback(
callbacks=callbacks, backup_dir=f"{CHECKPOINT}/backup"
)
model.fit(
train_ds,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment