From 9e3b0c003b798d6ff09fa8911ed466339ef91623 Mon Sep 17 00:00:00 2001
From: Amir MOHAMMADI <amir.mohammadi@idiap.ch>
Date: Tue, 27 Oct 2020 14:26:23 +0100
Subject: [PATCH] Add a callback that backs up and restores other callbacks

---
 bob/learn/tensorflow/callbacks.py             | 100 ++++++++++++++++++
 ...centerloss_mixed_precision_multi_worker.py |  61 ++++-------
 2 files changed, 118 insertions(+), 43 deletions(-)
 create mode 100644 bob/learn/tensorflow/callbacks.py

diff --git a/bob/learn/tensorflow/callbacks.py b/bob/learn/tensorflow/callbacks.py
new file mode 100644
index 00000000..efd97e09
--- /dev/null
+++ b/bob/learn/tensorflow/callbacks.py
@@ -0,0 +1,100 @@
+import json
+import os
+
+import tensorflow as tf
+from tensorflow.keras import callbacks
+
+
+class CustomBackupAndRestore(tf.keras.callbacks.experimental.BackupAndRestore):
+    """This callback is experimental and might be removed in the future.
+    See :any:`add_backup_callback`.
+    """
+
+    def __init__(self, callbacks, backup_dir, **kwargs):
+        super().__init__(backup_dir=backup_dir, **kwargs)
+        self.callbacks = callbacks
+        self.callbacks_backup_path = os.path.join(self.backup_dir, "callbacks.json")
+
+    def backup(self):
+        variables = {}
+        for cb_name, cb in self.callbacks.items():
+            variables[cb_name] = {}
+            for k, v in cb.__dict__.items():
+                if not isinstance(v, (int, float)):
+                    continue
+                variables[cb_name][k] = v
+        with open(self.callbacks_backup_path, "w") as f:
+            json.dump(variables, f, indent=4, sort_keys=True)
+
+    def restore(self):
+        if not os.path.isfile(self.callbacks_backup_path):
+            return False
+
+        with open(self.callbacks_backup_path, "r") as f:
+            variables = json.load(f)
+
+        for cb_name, cb in self.callbacks.items():
+            if cb_name not in variables:
+                continue
+            for k, v in cb.__dict__.items():
+                if k in variables[cb_name]:
+                    cb.__dict__[k] = variables[cb_name][k]
+
+        return True
+
+    def on_train_begin(self, logs=None):
+        super().on_train_begin(logs=logs)
+        if self.restore():
+            print(f"Restored callbacks from {self.callbacks_backup_path}")
+        else:
+            print("Did not restore callbacks")
+
+    def on_epoch_end(self, epoch, logs=None):
+        super().on_epoch_end(epoch, logs=logs)
+        self.backup()
+
+    def on_train_end(self, logs=None):
+        # do not delete backups
+        pass
+
+
+def add_backup_callback(callbacks, backup_dir, **kwargs):
+    """Adds a backup callback to your callbacks so that the training process
+    can be restored if it is interrupted.
+
+    .. warning::
+
+        This function is experimental and may be removed or changed in the future.
+
+    Examples
+    --------
+
+    >>> CHECKPOINT = "checkpoints"
+    >>> callbacks = {
+    ...     "best": tf.keras.callbacks.ModelCheckpoint(
+    ...         f"{CHECKPOINT}/best",
+    ...         monitor="val_acc",
+    ...         save_best_only=True,
+    ...         mode="max",
+    ...         verbose=1,
+    ...     ),
+    ...     "tensorboard": tf.keras.callbacks.TensorBoard(
+    ...         log_dir=f"{CHECKPOINT}/logs",
+    ...         update_freq=15,
+    ...         write_graph=False,
+    ...     ),
+    ... }
+    >>> callbacks = add_backup_callback(callbacks, f"{CHECKPOINT}/backup")
+    >>> # callbacks will be a list that can be given to model.fit
+    >>> isinstance(callbacks, list)
+    True
+    """
+    if not isinstance(callbacks, dict):
+        raise ValueError(
+            "Please provide a dictionary of callbacks where "
+            "keys are simple names for your callbacks!"
+        )
+    cb = CustomBackupAndRestore(callbacks=callbacks, backup_dir=backup_dir, **kwargs)
+    callbacks = list(callbacks.values())
+    callbacks.append(cb)
+    return callbacks
diff --git a/examples/MSCeleba_centerloss_mixed_precision_multi_worker.py b/examples/MSCeleba_centerloss_mixed_precision_multi_worker.py
index 97fab2a5..bd4b7e53 100644
--- a/examples/MSCeleba_centerloss_mixed_precision_multi_worker.py
+++ b/examples/MSCeleba_centerloss_mixed_precision_multi_worker.py
@@ -1,5 +1,4 @@
 #!/usr/bin/env python
-# coding: utf-8
 
 import os
 import pickle
@@ -8,12 +7,14 @@ from multiprocessing import cpu_count
 
 import pkg_resources
 import tensorflow as tf
-from bob.learn.tensorflow.losses import CenterLoss, CenterLossLayer
+from bob.extension import rc
+from bob.learn.tensorflow.callbacks import add_backup_callback
+from bob.learn.tensorflow.losses import CenterLoss
+from bob.learn.tensorflow.losses import CenterLossLayer
 from bob.learn.tensorflow.models.inception_resnet_v2 import InceptionResNetV2
 from bob.learn.tensorflow.utils import predict_using_tensors
 from tensorflow.keras import layers
 from tensorflow.keras.mixed_precision import experimental as mixed_precision
-from bob.extension import rc
 
 policy = mixed_precision.Policy("mixed_float16")
 mixed_precision.set_policy(policy)
@@ -29,9 +30,7 @@ VALIDATION_TF_RECORD_PATHS = (
 # there are 2812 samples in the validation set
 VALIDATION_SAMPLES = 2812
 
-CHECKPOINT = (
-    f"{rc['temp']}/models/inception_v2_batchnorm_rgb_msceleba_mixed_precision"
-)
+CHECKPOINT = f"{rc['temp']}/models/inception_v2_batchnorm_rgb_msceleba_mixed_precision"
 
 AUTOTUNE = tf.data.experimental.AUTOTUNE
 TFRECORD_PARALLEL_READ = cpu_count()
@@ -113,7 +112,8 @@ def prepare_dataset(tf_record_paths, batch_size, shuffle=False, augment=False):
         ds = ds.shuffle(SHUFFLE_BUFFER).repeat(EPOCHS)
     preprocessor = get_preprocessor()
     ds = ds.batch(batch_size).map(
-        partial(preprocess, preprocessor, augment=augment), num_parallel_calls=AUTOTUNE,
+        partial(preprocess, preprocessor, augment=augment),
+        num_parallel_calls=AUTOTUNE,
     )
 
     # Use buffered prefecting on all datasets
@@ -247,33 +247,6 @@ def build_and_compile_model(global_batch_size):
     return model
 
 
-class CustomBackupAndRestore(tf.keras.callbacks.experimental.BackupAndRestore):
-    def __inti__(self, custom_objects, **kwargs):
-        super().__inti__(**kwargs)
-        self.custom_objects = custom_objects
-        self.custom_objects_path = os.path.join(self.backup_dir, "custom_objects.pkl")
-
-    def on_epoch_end(self, epoch, logs=None):
-        super().on_epoch_end(epoch, logs=logs)
-
-        # pickle custom objects
-        with open(self.custom_objects_path, "wb") as f:
-            pickle.dump(self.custom_objects, f)
-
-    def on_train_begin(self, logs=None):
-        super().on_train_begin(logs=logs)
-        if not os.path.exists(self.custom_objects_path):
-            return
-
-        # load custom objects
-        with open(self.custom_objects_path, "rb") as f:
-            self.custom_objects = pickle.load(f)
-
-    def on_train_end(self, logs=None):
-        # do not delete backups
-        pass
-
-
 def train_and_evaluate(tf_config):
     os.environ["TF_CONFIG"] = json.dumps(tf_config)
 
@@ -312,25 +285,27 @@ def train_and_evaluate(tf_config):
         else:
             return 0.001
 
-    callbacks = [
-        tf.keras.callbacks.ModelCheckpoint(f"{CHECKPOINT}/latest", verbose=1),
-        tf.keras.callbacks.ModelCheckpoint(
+    callbacks = {
+        "latest": tf.keras.callbacks.ModelCheckpoint(f"{CHECKPOINT}/latest", verbose=1),
+        "best": tf.keras.callbacks.ModelCheckpoint(
             f"{CHECKPOINT}/best",
             monitor=val_metric_name,
             save_best_only=True,
             mode="max",
             verbose=1,
         ),
-        tf.keras.callbacks.TensorBoard(
"tensorboard": tf.keras.callbacks.TensorBoard( log_dir=f"{CHECKPOINT}/logs", update_freq=15, profile_batch="10,50" ), - tf.keras.callbacks.LearningRateScheduler(scheduler, verbose=1), - # tf.keras.callbacks.ReduceLROnPlateau( + "lr": tf.keras.callbacks.LearningRateScheduler(scheduler, verbose=1), + # "lr": tf.keras.callbacks.ReduceLROnPlateau( # monitor=val_metric_name, factor=0.2, patience=5, min_lr=0.001 # ), - tf.keras.callbacks.TerminateOnNaN(), - ] - callbacks.append(CustomBackupAndRestore(backup_dir=f"{CHECKPOINT}/backup", custom_objects=callbacks)) + "nan": tf.keras.callbacks.TerminateOnNaN(), + } + callbacks = add_backup_callback( + callbacks=callbacks, backup_dir=f"{CHECKPOINT}/backup" + ) model.fit( train_ds, -- GitLab