diff --git a/bob/learn/tensorflow/utils/keras.py b/bob/learn/tensorflow/utils/keras.py
index 0181db0cc191f48d9790086176b70e7209e67720..c74021620069100b909d82a2c77998b976e4643a 100644
--- a/bob/learn/tensorflow/utils/keras.py
+++ b/bob/learn/tensorflow/utils/keras.py
@@ -50,13 +50,24 @@ def apply_trainable_variables_on_keras_model(model, trainable_variables, mode):
             layer.trainable = trainable
 
 
-def restore_model_variables_from_checkpoint(model, checkpoint, session=None):
+def _create_var_map(variables, normalizer=None):
+    if normalizer is None:
+
+        def normalizer(name):
+            return name.split(":")[0]
+
+    assignment_map = {normalizer(v.name): v for v in variables}
+    assert len(assignment_map)
+    return assignment_map
+
+
+def restore_model_variables_from_checkpoint(
+    model, checkpoint, session=None, normalizer=None
+):
     if session is None:
         session = tf.keras.backend.get_session()
 
-    # removes duplicates
-    var_list = set(model.variables)
-    assert len(var_list)
+    var_list = _create_var_map(model.variables, normalizer=normalizer)
     saver = tf.train.Saver(var_list=var_list)
     ckpt_state = tf.train.get_checkpoint_state(checkpoint)
     logger.info("Loading checkpoint %s", ckpt_state.model_checkpoint_path)
@@ -64,9 +75,267 @@ def restore_model_variables_from_checkpoint(model, checkpoint, session=None):
 
 
 def initialize_model_from_checkpoint(model, checkpoint, normalizer=None):
-    if normalizer is None:
-        def normalizer(name):
-            return name.split(":")[0]
-    assignment_map = {normalizer(v.name): v for v in model.variables}
-    assert len(assignment_map)
+    assignment_map = _create_var_map(model.variables, normalizer=normalizer)
     tf.train.init_from_checkpoint(checkpoint, assignment_map=assignment_map)
+
+
+def model_summary(model, do_print=False):
+    try:
+        from tensorflow.python.keras.utils.layer_utils import count_params
+    except ImportError:
+        from tensorflow_core.python.keras.utils.layer_utils import count_params
+    nest = tf.nest
+
+    if model.__class__.__name__ == "Sequential":
+        sequential_like = True
+    elif not model._is_graph_network:
+        # We treat subclassed models as a simple sequence of layers, for logging
+        # purposes.
+        sequential_like = True
+    else:
+        sequential_like = True
+        nodes_by_depth = model._nodes_by_depth.values()
+        nodes = []
+        for v in nodes_by_depth:
+            if (len(v) > 1) or (
+                len(v) == 1 and len(nest.flatten(v[0].inbound_layers)) > 1
+            ):
+                # if the model has multiple nodes
+                # or if the nodes have multiple inbound_layers
+                # the model is no longer sequential
+                sequential_like = False
+                break
+            nodes += v
+        if sequential_like:
+            # search for shared layers
+            for layer in model.layers:
+                flag = False
+                for node in layer._inbound_nodes:
+                    if node in nodes:
+                        if flag:
+                            sequential_like = False
+                            break
+                        else:
+                            flag = True
+                if not sequential_like:
+                    break
+
+    if sequential_like:
+        # header names for the different log elements
+        to_display = ["Layer (type)", "Details", "Output Shape", "Number of Parameters"]
+    else:
+        # header names for the different log elements
+        to_display = [
+            "Layer (type)",
+            "Details",
+            "Output Shape",
+            "Number of Parameters",
+            "Connected to",
+        ]
+        relevant_nodes = []
+        for v in model._nodes_by_depth.values():
+            relevant_nodes += v
+
+    rows = [to_display]
+
+    def print_row(fields):
+        for i, v in enumerate(fields):
+            if isinstance(v, int):
+                fields[i] = f"{v:,}"
+        rows.append(fields)
+
+    def layer_details(layer):
+        cls_name = layer.__class__.__name__
+        details = []
+        if "Conv" in cls_name and "ConvBlock" not in cls_name:
+            details += [f"filters={layer.filters}"]
+            details += [f"kernel_size={layer.kernel_size}"]
+
+        if "Pool" in cls_name and "Global" not in cls_name:
+            details += [f"pool_size={layer.pool_size}"]
+
+        if (
+            "Conv" in cls_name
+            and "ConvBlock" not in cls_name
+            or "Pool" in cls_name
+            and "Global" not in cls_name
+        ):
+            details += [f"strides={layer.strides}"]
+
+        if (
+            "ZeroPad" in cls_name
+            or cls_name in ("Conv1D", "Conv2D", "Conv3D")
+            or "Pool" in cls_name
+            and "Global" not in cls_name
+        ):
+            details += [f"padding={layer.padding}"]
+
+        if "Cropping" in cls_name:
+            details += [f"cropping={layer.cropping}"]
+
+        if cls_name == "Dense":
+            details += [f"units={layer.units}"]
+
+        if cls_name in ("Conv1D", "Conv2D", "Conv3D") or cls_name == "Dense":
+            act = layer.activation.__name__
+            if act != "linear":
+                details += [f"activation={act}"]
+
+        if cls_name == "Dropout":
+            details += [f"drop_rate={layer.rate}"]
+
+        if cls_name == "Concatenate":
+            details += [f"axis={layer.axis}"]
+
+        if cls_name == "Activation":
+            act = layer.get_config()["activation"]
+            details += [f"activation={act}"]
+
+        if "InceptionModule" in cls_name:
+            details += [f"b1_c1={layer.filter_1x1}"]
+            details += [f"b2_c1={layer.filter_3x3_reduce}"]
+            details += [f"b2_c2={layer.filter_3x3}"]
+            details += [f"b3_c1={layer.filter_5x5_reduce}"]
+            details += [f"b3_c2={layer.filter_5x5}"]
+            details += [f"b4_c1={layer.pool_proj}"]
+
+        if cls_name == "LRN":
+            details += [f"depth_radius={layer.depth_radius}"]
+            details += [f"alpha={layer.alpha}"]
+            details += [f"beta={layer.beta}"]
+
+        if cls_name == "ConvBlock":
+            details += [f"filters={layer.num_filters}"]
+            details += [f"bottleneck={layer.bottleneck}"]
+            details += [f"dropout_rate={layer.dropout_rate}"]
+
+        if cls_name == "DenseBlock":
+            details += [f"layers={layer.num_layers}"]
+            details += [f"growth_rate={layer.growth_rate}"]
+            details += [f"bottleneck={layer.bottleneck}"]
+            details += [f"dropout_rate={layer.dropout_rate}"]
+
+        if cls_name == "TransitionBlock":
+            details += [f"filters={layer.num_filters}"]
+
+        if cls_name == "InceptionA":
+            details += [f"pool_filters={layer.pool_filters}"]
+
+        if cls_name == "InceptionResnetBlock":
+            details += [f"block_type={layer.block_type}"]
+            details += [f"scale={layer.scale}"]
+            details += [f"n={layer.n}"]
+
+        if cls_name == "ReductionA":
+            details += [f"k={layer.k}"]
+            details += [f"kl={layer.kl}"]
+            details += [f"km={layer.km}"]
+            details += [f"n={layer.n}"]
+
+        if cls_name == "ReductionB":
+            details += [f"k={layer.k}"]
+            details += [f"kl={layer.kl}"]
+            details += [f"km={layer.km}"]
+            details += [f"n={layer.n}"]
+            details += [f"no={layer.no}"]
+            details += [f"p={layer.p}"]
+            details += [f"pq={layer.pq}"]
+
+        if cls_name == "ScaledResidual":
+            details += [f"scale={layer.scale}"]
+
+        return ", ".join(details)
+
+    def print_layer_summary(layer):
+        """Prints a summary for a single layer.
+
+        Arguments:
+                layer: target layer.
+        """
+        try:
+            output_shape = layer.output_shape
+        except AttributeError:
+            output_shape = "multiple"
+        except RuntimeError:  # output_shape unknown in Eager mode.
+            output_shape = "?"
+        name = layer.name
+        cls_name = layer.__class__.__name__
+        fields = [
+            name + " (" + cls_name + ")",
+            layer_details(layer),
+            output_shape,
+            layer.count_params(),
+        ]
+        print_row(fields)
+
+    def print_layer_summary_with_connections(layer):
+        """Prints a summary for a single layer (including topological connections).
+
+        Arguments:
+                layer: target layer.
+        """
+        try:
+            output_shape = layer.output_shape
+        except AttributeError:
+            output_shape = "multiple"
+        connections = []
+        for node in layer._inbound_nodes:
+            if relevant_nodes and node not in relevant_nodes:
+                # node is not part of the current network
+                continue
+
+            for inbound_layer, node_index, tensor_index, _ in node.iterate_inbound():
+                connections.append(
+                    "{}[{}][{}]".format(inbound_layer.name, node_index, tensor_index)
+                )
+
+        name = layer.name
+        cls_name = layer.__class__.__name__
+        if not connections:
+            first_connection = ""
+        else:
+            first_connection = connections[0]
+        fields = [
+            name + " (" + cls_name + ")",
+            layer_details(layer),
+            output_shape,
+            layer.count_params(),
+            first_connection,
+        ]
+        print_row(fields)
+        if len(connections) > 1:
+            for i in range(1, len(connections)):
+                fields = ["", "", "", "", connections[i]]
+                print_row(fields)
+
+    layers = model.layers
+    for i in range(len(layers)):
+        if sequential_like:
+            print_layer_summary(layers[i])
+        else:
+            print_layer_summary_with_connections(layers[i])
+
+    model._check_trainable_weights_consistency()
+    if hasattr(model, "_collected_trainable_weights"):
+        trainable_count = count_params(model._collected_trainable_weights)
+    else:
+        trainable_count = count_params(model.trainable_weights)
+
+    non_trainable_count = count_params(model.non_trainable_weights)
+
+    print_row([])
+    print_row(
+        [
+            "Model",
+            f"Parameters: total={trainable_count + non_trainable_count:,}, trainable={trainable_count:,}",
+        ]
+    )
+
+    if do_print:
+        from tabulate import tabulate
+
+        print()
+        print(tabulate(rows, headers="firstrow"))
+        print()
+
+    return rows