From 9e3b0c003b798d6ff09fa8911ed466339ef91623 Mon Sep 17 00:00:00 2001
From: Amir MOHAMMADI <amir.mohammadi@idiap.ch>
Date: Tue, 27 Oct 2020 14:26:23 +0100
Subject: [PATCH] Add a callback that backups and restores other callbacks

---
 bob/learn/tensorflow/callbacks.py             | 100 ++++++++++++++++++
 ...centerloss_mixed_precision_multi_worker.py |  61 ++++-------
 2 files changed, 118 insertions(+), 43 deletions(-)
 create mode 100644 bob/learn/tensorflow/callbacks.py

diff --git a/bob/learn/tensorflow/callbacks.py b/bob/learn/tensorflow/callbacks.py
new file mode 100644
index 00000000..efd97e09
--- /dev/null
+++ b/bob/learn/tensorflow/callbacks.py
@@ -0,0 +1,100 @@
+import json
+import os
+
+import tensorflow as tf
+from tensorflow.keras import callbacks
+
+
+class CustomBackupAndRestore(tf.keras.callbacks.experimental.BackupAndRestore):
+    """This callback is experimental and might be removed in future.
+    See :any:`add_backup_callback`
+    """
+
+    def __init__(self, callbacks, backup_dir, **kwargs):
+        super().__init__(backup_dir=backup_dir, **kwargs)
+        self.callbacks = callbacks
+        self.callbacks_backup_path = os.path.join(self.backup_dir, "callbacks.json")
+
+    def backup(self):
+        variables = {}
+        for cb_name, cb in self.callbacks.items():
+            variables[cb_name] = {}
+            for k, v in cb.__dict__.items():
+                if not isinstance(v, (int, float)):
+                    continue
+                variables[cb_name][k] = v
+        with open(self.callbacks_backup_path, "w") as f:
+            json.dump(variables, f, indent=4, sort_keys=True)
+
+    def restore(self):
+        if not os.path.isfile(self.callbacks_backup_path):
+            return False
+
+        with open(self.callbacks_backup_path, "r") as f:
+            variables = json.load(f)
+
+        for cb_name, cb in self.callbacks.items():
+            if cb_name not in variables:
+                continue
+            for k, v in cb.__dict__.items():
+                if k in variables[cb_name]:
+                    cb.__dict__[k] = variables[cb_name][k]
+
+        return True
+
+    def on_train_begin(self, logs=None):
+        super().on_train_begin(logs=logs)
+        if self.restore():
+            print(f"Restored callbacks from {self.callbacks_backup_path}")
+        else:
+            print("Did not restore callbacks")
+
+    def on_epoch_end(self, epoch, logs=None):
+        super().on_epoch_end(epoch, logs=logs)
+        self.backup()
+
+    def on_train_end(self, logs=None):
+        # do not delete backups
+        pass
+
+
+def add_backup_callback(callbacks, backup_dir, **kwargs):
+    """Adds a backup callback to your callbacks to restore the training process
+    if it is interrupted.
+
+    .. warning::
+
+        This function is experimental and may be removed or changed in future.
+
+    Examples
+    --------
+
+    >>> CHECKPOINT = "checkpoints"
+    >>> callbacks = {
+    ...     "best": tf.keras.callbacks.ModelCheckpoint(
+    ...         f"{CHECKPOINT}/best",
+    ...         monitor="val_acc",
+    ...         save_best_only=True,
+    ...         mode="max",
+    ...         verbose=1,
+    ...     ),
+    ...     "tensorboard": tf.keras.callbacks.TensorBoard(
+    ...         log_dir=f"{CHECKPOINT}/logs",
+    ...         update_freq=15,
+    ...         write_graph=False,
+    ...     ),
+    ... }
+    >>> callbacks = add_backup_callback(callbacks, f"{CHECKPOINT}/backup")
+    >>> # callbacks will be a list that can be given to model.fit
+    >>> isinstance(callbacks, list)
+    True
+    """
+    if not isinstance(callbacks, dict):
+        raise ValueError(
+            "Please provide a dictionary of callbacks where "
+            "keys are simple names for your callbacks!"
+        )
+    cb = CustomBackupAndRestore(callbacks=callbacks, backup_dir=backup_dir, **kwargs)
+    callbacks = list(callbacks.values())
+    callbacks.append(cb)
+    return callbacks
diff --git a/examples/MSCeleba_centerloss_mixed_precision_multi_worker.py b/examples/MSCeleba_centerloss_mixed_precision_multi_worker.py
index 97fab2a5..bd4b7e53 100644
--- a/examples/MSCeleba_centerloss_mixed_precision_multi_worker.py
+++ b/examples/MSCeleba_centerloss_mixed_precision_multi_worker.py
@@ -1,5 +1,4 @@
 #!/usr/bin/env python
-# coding: utf-8
 
 import os
 import pickle
@@ -8,12 +7,14 @@ from multiprocessing import cpu_count
 
 import pkg_resources
 import tensorflow as tf
-from bob.learn.tensorflow.losses import CenterLoss, CenterLossLayer
+from bob.extension import rc
+from bob.learn.tensorflow.callbacks import add_backup_callback
+from bob.learn.tensorflow.losses import CenterLoss
+from bob.learn.tensorflow.losses import CenterLossLayer
 from bob.learn.tensorflow.models.inception_resnet_v2 import InceptionResNetV2
 from bob.learn.tensorflow.utils import predict_using_tensors
 from tensorflow.keras import layers
 from tensorflow.keras.mixed_precision import experimental as mixed_precision
-from bob.extension import rc
 
 policy = mixed_precision.Policy("mixed_float16")
 mixed_precision.set_policy(policy)
@@ -29,9 +30,7 @@ VALIDATION_TF_RECORD_PATHS = (
 # there are 2812 samples in the validation set
 VALIDATION_SAMPLES = 2812
 
-CHECKPOINT = (
-    f"{rc['temp']}/models/inception_v2_batchnorm_rgb_msceleba_mixed_precision"
-)
+CHECKPOINT = f"{rc['temp']}/models/inception_v2_batchnorm_rgb_msceleba_mixed_precision"
 
 AUTOTUNE = tf.data.experimental.AUTOTUNE
 TFRECORD_PARALLEL_READ = cpu_count()
@@ -113,7 +112,8 @@ def prepare_dataset(tf_record_paths, batch_size, shuffle=False, augment=False):
         ds = ds.shuffle(SHUFFLE_BUFFER).repeat(EPOCHS)
     preprocessor = get_preprocessor()
     ds = ds.batch(batch_size).map(
-        partial(preprocess, preprocessor, augment=augment), num_parallel_calls=AUTOTUNE,
+        partial(preprocess, preprocessor, augment=augment),
+        num_parallel_calls=AUTOTUNE,
     )
 
     # Use buffered prefecting on all datasets
@@ -247,33 +247,6 @@ def build_and_compile_model(global_batch_size):
     return model
 
 
-class CustomBackupAndRestore(tf.keras.callbacks.experimental.BackupAndRestore):
-    def __inti__(self, custom_objects, **kwargs):
-        super().__inti__(**kwargs)
-        self.custom_objects = custom_objects
-        self.custom_objects_path = os.path.join(self.backup_dir, "custom_objects.pkl")
-
-    def on_epoch_end(self, epoch, logs=None):
-        super().on_epoch_end(epoch, logs=logs)
-
-        # pickle custom objects
-        with open(self.custom_objects_path, "wb") as f:
-            pickle.dump(self.custom_objects, f)
-
-    def on_train_begin(self, logs=None):
-        super().on_train_begin(logs=logs)
-        if not os.path.exists(self.custom_objects_path):
-            return
-
-        # load custom objects
-        with open(self.custom_objects_path, "rb") as f:
-            self.custom_objects = pickle.load(f)
-
-    def on_train_end(self, logs=None):
-        # do not delete backups
-        pass
-
-
 def train_and_evaluate(tf_config):
     os.environ["TF_CONFIG"] = json.dumps(tf_config)
 
@@ -312,25 +285,27 @@ def train_and_evaluate(tf_config):
         else:
             return 0.001
 
-    callbacks = [
-        tf.keras.callbacks.ModelCheckpoint(f"{CHECKPOINT}/latest", verbose=1),
-        tf.keras.callbacks.ModelCheckpoint(
+    callbacks = {
+        "latest": tf.keras.callbacks.ModelCheckpoint(f"{CHECKPOINT}/latest", verbose=1),
+        "best": tf.keras.callbacks.ModelCheckpoint(
             f"{CHECKPOINT}/best",
             monitor=val_metric_name,
             save_best_only=True,
             mode="max",
             verbose=1,
         ),
-        tf.keras.callbacks.TensorBoard(
+        "tensorboard": tf.keras.callbacks.TensorBoard(
             log_dir=f"{CHECKPOINT}/logs", update_freq=15, profile_batch="10,50"
         ),
-        tf.keras.callbacks.LearningRateScheduler(scheduler, verbose=1),
-        # tf.keras.callbacks.ReduceLROnPlateau(
+        "lr": tf.keras.callbacks.LearningRateScheduler(scheduler, verbose=1),
+        # "lr": tf.keras.callbacks.ReduceLROnPlateau(
         #     monitor=val_metric_name, factor=0.2, patience=5, min_lr=0.001
         # ),
-        tf.keras.callbacks.TerminateOnNaN(),
-    ]
-    callbacks.append(CustomBackupAndRestore(backup_dir=f"{CHECKPOINT}/backup", custom_objects=callbacks))
+        "nan": tf.keras.callbacks.TerminateOnNaN(),
+    }
+    callbacks = add_backup_callback(
+        callbacks=callbacks, backup_dir=f"{CHECKPOINT}/backup"
+    )
 
     model.fit(
         train_ds,
-- 
GitLab