diff --git a/bob/learn/tensorflow/configs/MirroredStrategy.py b/bob/learn/tensorflow/configs/MirroredStrategy.py
new file mode 100644
index 0000000000000000000000000000000000000000..5273ac3465a70d7c7b7371d7637b19c4428ec536
--- /dev/null
+++ b/bob/learn/tensorflow/configs/MirroredStrategy.py
@@ -0,0 +1,9 @@
+import tensorflow as tf
+
+
+def strategy_fn():
+    print("Creating MirroredStrategy strategy.")
+    strategy = tf.distribute.MirroredStrategy()
+    print("MirroredStrategy strategy created.")
+    print("Number of devices: {}".format(strategy.num_replicas_in_sync))
+    return strategy
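+
+
+# This function is registered in setup.py under the
+# ``bob.learn.tensorflow.strategy`` entry point group as ``mirrored-strategy``,
+# so it can be selected on the command line, for example:
+#
+#   bob keras fit --strategy-fn mirrored-strategy ...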
diff --git a/bob/learn/tensorflow/configs/MultiWorkerMirroredStrategy.py b/bob/learn/tensorflow/configs/MultiWorkerMirroredStrategy.py
new file mode 100644
index 0000000000000000000000000000000000000000..89d82405f25f97a305e70c649fdbf4f2232610de
--- /dev/null
+++ b/bob/learn/tensorflow/configs/MultiWorkerMirroredStrategy.py
@@ -0,0 +1,12 @@
+import tensorflow as tf
+
+
+def strategy_fn():
+    print("Creating MultiWorkerMirroredStrategy strategy.")
+    strategy = tf.distribute.MultiWorkerMirroredStrategy(
+        communication_options=tf.distribute.experimental.CommunicationOptions(
+            implementation=tf.distribute.experimental.CollectiveCommunication.NCCL
+        )
+    )
+    print("MultiWorkerMirroredStrategy strategy created.")
+    return strategy
diff --git a/bob/learn/tensorflow/data/generator.py b/bob/learn/tensorflow/data/generator.py
index 764996367e664e456620de151b60fa1f4a4eee04..7cdc63e64ea7600aea2755ec0c99dd17d33f078b 100644
--- a/bob/learn/tensorflow/data/generator.py
+++ b/bob/learn/tensorflow/data/generator.py
@@ -37,7 +37,7 @@ class Generator:
         reader,
         multiple_samples=False,
         shuffle_on_epoch_end=False,
-        **kwargs
+        **kwargs,
     ):
         super().__init__(**kwargs)
         self.reader = reader
@@ -65,6 +65,7 @@ class Generator:
         dataset = tf.data.Dataset.from_tensors(dlk)
         self._output_types = tf.compat.v1.data.get_output_types(dataset)
         self._output_shapes = tf.compat.v1.data.get_output_shapes(dataset)
+        self._element_spec = dataset.element_spec
 
         logger.info(
             "Initializing a dataset with %d %s and %s types and %s shapes",
@@ -84,6 +85,11 @@ class Generator:
         "The shapes of the returned samples"
         return self._output_shapes
 
+    @property
+    def element_spec(self):
+        "The type specification of an element of the dataset"
+        return self._element_spec
+
     def __call__(self):
         """A generator function that when called will yield the samples.
 
@@ -106,7 +112,13 @@ class Generator:
             random.shuffle(self.samples)
 
 
-def dataset_using_generator(samples, reader, **kwargs):
+def dataset_using_generator(
+    samples,
+    reader,
+    multiple_samples=False,
+    shuffle_on_epoch_end=False,
+    **kwargs,
+):
     """
     A generator class which wraps samples so that they can
     be used with tf.data.Dataset.from_generator
@@ -128,8 +140,14 @@ def dataset_using_generator(samples, reader, **kwargs):
         A tf.data.Dataset
     """
 
-    generator = Generator(samples, reader, **kwargs)
+    generator = Generator(
+        samples,
+        reader,
+        multiple_samples=multiple_samples,
+        shuffle_on_epoch_end=shuffle_on_epoch_end,
+        **kwargs,
+    )
     dataset = tf.data.Dataset.from_generator(
-        generator, generator.output_types, generator.output_shapes
+        generator, output_signature=generator.element_spec
     )
     return dataset
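+
+
+# A minimal usage sketch (``samples`` and the ``load`` helper are hypothetical):
+#
+#   def reader(sample):
+#       data = load(sample)  # hypothetical function loading one sample
+#       return data, sample.label
+#
+#   dataset = dataset_using_generator(samples, reader)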
diff --git a/bob/learn/tensorflow/layers.py b/bob/learn/tensorflow/layers.py
index 337858788de27daa927b7db4311f7c33f1fa6b3f..20ee4d58b5cec9108e16cb52bf7af11e2713a8e8 100644
--- a/bob/learn/tensorflow/layers.py
+++ b/bob/learn/tensorflow/layers.py
@@ -263,7 +263,14 @@ from tensorflow.keras.layers import Flatten
 
 
 def add_bottleneck(
-    model, bottleneck_size=128, dropout_rate=0.2, w_decay=5e-4, use_bias=True
+    model,
+    bottleneck_size=128,
+    dropout_rate=0.2,
+    w_decay=5e-4,
+    use_bias=True,
+    batch_norm_decay=0.99,
+    batch_norm_epsilon=1e-3,
+    batch_norm_scale=True,
 ):
     """
     Append a bottleneck layer to a Keras model
@@ -286,7 +293,13 @@ def add_bottleneck(
     else:
         new_model = model
 
-    new_model.add(BatchNormalization())
+    new_model.add(
+        BatchNormalization(
+            momentum=batch_norm_decay,
+            epsilon=batch_norm_epsilon,
+            scale=batch_norm_scale,
+        )
+    )
     new_model.add(Dropout(dropout_rate, name="Dropout"))
     new_model.add(Flatten())
 
@@ -300,11 +313,19 @@ def add_bottleneck(
             bottleneck_size,
             use_bias=use_bias,
             kernel_regularizer=regularizer,
+            dtype="float32",
         )
     )
 
-    new_model.add(BatchNormalization(axis=-1, name="embeddings"))
-    # new_model.add(BatchNormalization())
+    new_model.add(
+        BatchNormalization(
+            name="embeddings",
+            momentum=batch_norm_decay,
+            epsilon=batch_norm_epsilon,
+            scale=batch_norm_scale,
+            dtype="float32",
+        )
+    )
 
     return new_model
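+
+# A minimal usage sketch (``backbone`` is an assumption; any Keras model,
+# Sequential or functional, can be passed):
+#
+#   embedding_model = add_bottleneck(backbone, bottleneck_size=128)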
 
diff --git a/bob/learn/tensorflow/models/__init__.py b/bob/learn/tensorflow/models/__init__.py
index 333ab1f4f7446e182a5ce7348eb1f1eb3a747cac..152be75fd576ac2a35e0a5f9560209163e1fccef 100644
--- a/bob/learn/tensorflow/models/__init__.py
+++ b/bob/learn/tensorflow/models/__init__.py
@@ -1,11 +1,9 @@
 from .alexnet import AlexNet_simplified
 from .arcface import ArcFaceLayer
 from .arcface import ArcFaceLayer3Penalties
-from .arcface import ArcFaceModel
 from .densenet import DeepPixBiS
 from .densenet import DenseNet
 from .densenet import densenet161  # noqa: F401
-from .embedding_validation import EmbeddingValidation
 from .mine import MineModel
 from .resnet50_modified import resnet50_modified  # noqa: F401
 from .resnet50_modified import resnet101_modified  # noqa: F401
@@ -34,7 +32,5 @@ __appropriate__(
     MineModel,
     ArcFaceLayer,
     ArcFaceLayer3Penalties,
-    ArcFaceModel,
-    EmbeddingValidation,
 )
 __all__ = [_ for _ in dir() if not _.startswith("_")]
diff --git a/bob/learn/tensorflow/models/arcface.py b/bob/learn/tensorflow/models/arcface.py
index 5c856900c8d5fd3237da1faf2d1c45237389a8fe..27cf85f4574ca0c4335310b3ec58803ec6a638c0 100644
--- a/bob/learn/tensorflow/models/arcface.py
+++ b/bob/learn/tensorflow/models/arcface.py
@@ -2,48 +2,6 @@ import math
 
 import tensorflow as tf
 
-from bob.learn.tensorflow.metrics.embedding_accuracy import accuracy_from_embeddings
-
-from .embedding_validation import EmbeddingValidation
-
-
-class ArcFaceModel(EmbeddingValidation):
-    def train_step(self, data):
-        X, y = data
-
-        with tf.GradientTape() as tape:
-
-            logits, _ = self((X, y), training=True)
-            loss = self.compiled_loss(
-                y, logits, sample_weight=None, regularization_losses=self.losses
-            )
-            reg_loss = tf.reduce_sum(self.losses)
-            total_loss = loss + reg_loss
-
-        trainable_vars = self.trainable_variables
-
-        self.optimizer.minimize(total_loss, trainable_vars, tape=tape)
-
-        self.compiled_metrics.update_state(y, logits, sample_weight=None)
-
-        tf.summary.scalar("arc_face_loss", data=loss, step=self._train_counter)
-        tf.summary.scalar("total_loss", data=total_loss, step=self._train_counter)
-
-        self.train_loss(loss)
-        return {m.name: m.result() for m in self.metrics + [self.train_loss]}
-
-    def test_step(self, data):
-        """
-        Test Step
-        """
-
-        images, labels = data
-
-        # No worries, labels not used in validation
-        _, embeddings = self((images, labels), training=False)
-        self.validation_acc(accuracy_from_embeddings(labels, embeddings))
-        return {m.name: m.result() for m in [self.validation_acc]}
-
 
 class ArcFaceLayer(tf.keras.layers.Layer):
     """
@@ -69,18 +27,29 @@ class ArcFaceLayer(tf.keras.layers.Layer):
          If `True`, uses arcface loss. If `False`, it's a regular dense layer
     """
 
-    def __init__(self, n_classes=10, s=30, m=0.5, arc=True):
-        super(ArcFaceLayer, self).__init__(name="arc_face_logits")
+    def __init__(
+        # don't forget to fix get_config when you change init params
+        self,
+        n_classes,
+        s=30,
+        m=0.5,
+        arc=True,
+        kernel_initializer=None,
+        name="arc_face_logits",
+        **kwargs,
+    ):
+        super().__init__(name=name, **kwargs)
         self.n_classes = n_classes
         self.s = s
         self.arc = arc
         self.m = m
+        self.kernel_initializer = kernel_initializer
 
     def build(self, input_shape):
         super(ArcFaceLayer, self).build(input_shape[0])
         shape = [input_shape[-1], self.n_classes]
 
-        self.W = self.add_variable("W", shape=shape)
+        self.W = self.add_weight("W", shape=shape, initializer=self.kernel_initializer)
 
         self.cos_m = tf.identity(math.cos(self.m), name="cos_m")
         self.sin_m = tf.identity(math.sin(self.m), name="sin_m")
@@ -100,14 +69,21 @@ class ArcFaceLayer(tf.keras.layers.Layer):
             sin_yi = tf.clip_by_value(tf.math.sqrt(1 - cos_yi ** 2), 0, 1)
 
             # cos(x+m) = cos(x)*cos(m) - sin(x)*sin(m)
-            cos_yi_m = cos_yi * self.cos_m - sin_yi * self.sin_m
+            dtype = cos_yi.dtype
+            cos_m = tf.cast(self.cos_m, dtype=dtype)
+            sin_m = tf.cast(self.sin_m, dtype=dtype)
+            th = tf.cast(self.th, dtype=dtype)
+            mm = tf.cast(self.mm, dtype=dtype)
 
-            cos_yi_m = tf.where(cos_yi > self.th, cos_yi_m, cos_yi - self.mm)
+            cos_yi_m = cos_yi * cos_m - sin_yi * sin_m
+
+            cos_yi_m = tf.where(cos_yi > th, cos_yi_m, cos_yi - mm)
 
             # Preparing the hot-output
             one_hot = tf.one_hot(
                 tf.cast(y, tf.int32), depth=self.n_classes, name="one_hot_mask"
             )
+            one_hot = tf.cast(one_hot, dtype=dtype)
 
             logits = (one_hot * cos_yi_m) + ((1.0 - one_hot) * cos_yi)
             logits = self.s * logits
@@ -116,6 +92,21 @@ class ArcFaceLayer(tf.keras.layers.Layer):
 
         return logits
 
+    def get_config(self):
+        config = dict(super().get_config())
+        config.update(
+            {
+                "n_classes": self.n_classes,
+                "s": self.s,
+                "arc": self.arc,
+                "m": self.m,
+                "kernel_initializer": tf.keras.initializers.serialize(
+                    self.kernel_initializer
+                ),
+            }
+        )
+        return config
+
 
 class ArcFaceLayer3Penalties(tf.keras.layers.Layer):
     """
@@ -126,8 +117,17 @@ class ArcFaceLayer3Penalties(tf.keras.layers.Layer):
       :math:`s(cos(m_1\\theta_i + m_2) - m_3)`
     """
 
-    def __init__(self, n_classes=10, s=30, m1=0.5, m2=0.5, m3=0.5):
-        super(ArcFaceLayer3Penalties, self).__init__(name="arc_face_logits")
+    def __init__(
+        self,
+        n_classes=10,
+        s=30,
+        m1=0.5,
+        m2=0.5,
+        m3=0.5,
+        name="arc_face_logits",
+        **kwargs,
+    ):
+        super().__init__(name=name, **kwargs)
         self.n_classes = n_classes
         self.s = s
         self.m1 = m1
@@ -170,3 +170,16 @@ class ArcFaceLayer3Penalties(tf.keras.layers.Layer):
 
         logits = self.s * logits
         return logits
+
+    def get_config(self):
+        config = dict(super().get_config())
+        config.update(
+            {
+                "n_classes": self.n_classes,
+                "s": self.s,
+                "m1": self.m1,
+                "m2": self.m2,
+                "m3": self.m3,
+            }
+        )
+        return config
diff --git a/bob/learn/tensorflow/models/embedding_validation.py b/bob/learn/tensorflow/models/embedding_validation.py
deleted file mode 100644
index beb4498bc3e1465c496f367bc1370b95f6a36168..0000000000000000000000000000000000000000
--- a/bob/learn/tensorflow/models/embedding_validation.py
+++ /dev/null
@@ -1,59 +0,0 @@
-import tensorflow as tf
-
-from bob.learn.tensorflow.metrics.embedding_accuracy import accuracy_from_embeddings
-
-
-class EmbeddingValidation(tf.keras.Model):
-    """
-    Use this model if the validation step should validate the accuracy with respect to embeddings.
-
-    In this model, the `test_step` runs the function `bob.learn.tensorflow.metrics.embedding_accuracy.accuracy_from_embeddings`
-    """
-
-    def compile(
-        self,
-        single_precision=False,
-        **kwargs,
-    ):
-        """
-        Compile
-        """
-        super().compile(**kwargs)
-        self.train_loss = tf.keras.metrics.Mean(name="accuracy")
-        self.validation_acc = tf.keras.metrics.Mean(name="accuracy")
-
-    def train_step(self, data):
-        """
-        Train Step
-        """
-
-        X, y = data
-
-        with tf.GradientTape() as tape:
-            logits, _ = self(X, training=True)
-            loss = self.loss(y, logits)
-
-        # trainable_vars = self.trainable_variables
-
-        self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
-
-        self.compiled_metrics.update_state(y, logits, sample_weight=None)
-        self.train_loss(loss)
-
-        tf.summary.scalar("training_loss", data=loss, step=self._train_counter)
-
-        return {m.name: m.result() for m in self.metrics + [self.train_loss]}
-
-        # self.optimizer.apply_gradients(zip(gradients, trainable_vars))
-        # self.train_loss(loss)
-        # return {m.name: m.result() for m in [self.train_loss]}
-
-    def test_step(self, data):
-        """
-        Test Step
-        """
-
-        images, labels = data
-        logits, prelogits = self(images, training=False)
-        self.validation_acc(accuracy_from_embeddings(labels, prelogits))
-        return {m.name: m.result() for m in [self.validation_acc]}
diff --git a/bob/learn/tensorflow/models/iresnet.py b/bob/learn/tensorflow/models/iresnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..335db66de42905f536c74cd8a732f97e59fabfc1
--- /dev/null
+++ b/bob/learn/tensorflow/models/iresnet.py
@@ -0,0 +1,251 @@
+"""iResNet models for Keras.
+Adapted from insightface/recognition/arcface_torch/backbones/iresnet.py
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+__all__ = ["iresnet18", "iresnet34", "iresnet50", "iresnet100", "iresnet200"]
+
+
+def _gen_l2_regularizer(use_l2_regularizer=True, l2_weight_decay=1e-4):
+    return tf.keras.regularizers.L2(l2_weight_decay) if use_l2_regularizer else None
+
+
+def _gen_initializer():
+    return tf.keras.initializers.RandomNormal(mean=0, stddev=0.1)
+
+
+def conv3x3(filters, stride=1, groups=1, dilation=1):
+    """3x3 convolution with padding"""
+    return tf.keras.layers.Conv2D(
+        filters,
+        kernel_size=3,
+        strides=stride,
+        padding="same" if dilation else "valid",
+        groups=groups,
+        use_bias=False,
+        dilation_rate=dilation,
+        kernel_initializer=_gen_initializer(),
+    )
+
+
+def conv1x1(filters, stride=1):
+    """1x1 convolution"""
+    return tf.keras.layers.Conv2D(
+        filters,
+        kernel_size=1,
+        strides=stride,
+        padding="valid",
+        use_bias=False,
+        kernel_initializer=_gen_initializer(),
+    )
+
+
+def IBasicBlock(
+    x, planes, stride=1, downsample=None, groups=1, base_width=64, dilation=1
+):
+    if groups != 1 or base_width != 64:
+        raise ValueError("BasicBlock only supports groups=1 and base_width=64")
+    if dilation > 1:
+        raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
+
+    bn1 = tf.keras.layers.BatchNormalization(
+        scale=False,
+        momentum=0.9,
+        epsilon=1e-05,
+    )
+    conv1 = conv3x3(planes)
+    bn2 = tf.keras.layers.BatchNormalization(
+        scale=False,
+        momentum=0.9,
+        epsilon=1e-05,
+    )
+    prelu = tf.keras.layers.PReLU(
+        alpha_initializer=tf.keras.initializers.Constant(0.25), shared_axes=[1, 2]
+    )
+    conv2 = conv3x3(planes, stride=stride)
+    bn3 = tf.keras.layers.BatchNormalization(
+        scale=False,
+        momentum=0.9,
+        epsilon=1e-05,
+    )
+
+    identity = x
+    out = bn1(x)
+    out = conv1(out)
+    out = bn2(out)
+    out = prelu(out)
+    out = conv2(out)
+    out = bn3(out)
+    if downsample is not None:
+        for layer in downsample:
+            identity = layer(identity)
+    out += identity
+    return out
+
+
+def _make_layer(
+    x, dilation, groups, base_width, block, planes, blocks, stride=1, dilate=False
+):
+    downsample = None
+    previous_dilation = dilation
+    if dilate:
+        dilation *= stride
+        stride = 1
+    if stride != 1 or x.shape[-1] != planes:
+        downsample = [
+            conv1x1(planes, stride),
+            tf.keras.layers.BatchNormalization(
+                scale=False,
+                momentum=0.9,
+                epsilon=1e-05,
+            ),
+        ]
+    x = block(x, planes, stride, downsample, groups, base_width, previous_dilation)
+    for _ in range(1, blocks):
+        x = block(x, planes, groups=groups, base_width=base_width, dilation=dilation)
+
+    return x, dilation
+
+
+def IResNet(
+    name,
+    input_shape,
+    block,
+    layers,
+    groups=1,
+    width_per_group=64,
+    replace_stride_with_dilation=None,
+):
+    x = img_input = tf.keras.layers.Input(shape=input_shape)
+    inplanes = 64
+    dilation = 1
+    if replace_stride_with_dilation is None:
+        replace_stride_with_dilation = [False, False, False]
+    if len(replace_stride_with_dilation) != 3:
+        raise ValueError(
+            "replace_stride_with_dilation should be None "
+            "or a 3-element tuple, got {}".format(replace_stride_with_dilation)
+        )
+    base_width = width_per_group
+    conv1 = conv3x3(inplanes, stride=1, dilation=1)
+    bn1 = tf.keras.layers.BatchNormalization(
+        scale=False,
+        momentum=0.9,
+        epsilon=1e-05,
+    )
+    prelu = tf.keras.layers.PReLU(
+        alpha_initializer=tf.keras.initializers.Constant(0.25), shared_axes=[1, 2]
+    )
+
+    x = conv1(x)
+    x = bn1(x)
+    x = prelu(x)
+
+    x, dilation = _make_layer(
+        x=x,
+        dilation=dilation,
+        base_width=base_width,
+        groups=groups,
+        block=block,
+        planes=64,
+        blocks=layers[0],
+        stride=2,
+    )
+    x, dilation = _make_layer(
+        x=x,
+        dilation=dilation,
+        base_width=base_width,
+        groups=groups,
+        block=block,
+        planes=128,
+        blocks=layers[1],
+        stride=2,
+        dilate=replace_stride_with_dilation[0],
+    )
+    x, dilation = _make_layer(
+        x=x,
+        dilation=dilation,
+        base_width=base_width,
+        groups=groups,
+        block=block,
+        planes=256,
+        blocks=layers[2],
+        stride=2,
+        dilate=replace_stride_with_dilation[1],
+    )
+    x, dilation = _make_layer(
+        x=x,
+        dilation=dilation,
+        base_width=base_width,
+        groups=groups,
+        block=block,
+        planes=512,
+        blocks=layers[3],
+        stride=2,
+        dilate=replace_stride_with_dilation[2],
+    )
+
+    return tf.keras.Model(img_input, x, name=name)
+
+
+def iresnet18(input_shape, **kwargs):
+    return IResNet(
+        name="iresnet18",
+        input_shape=input_shape,
+        block=IBasicBlock,
+        layers=[2, 2, 2, 2],
+        **kwargs,
+    )
+
+
+def iresnet34(input_shape, **kwargs):
+    return IResNet(
+        name="iresnet34",
+        input_shape=input_shape,
+        block=IBasicBlock,
+        layers=[3, 4, 6, 3],
+        **kwargs,
+    )
+
+
+def iresnet50(input_shape, **kwargs):
+    return IResNet(
+        name="iresnet50",
+        input_shape=input_shape,
+        block=IBasicBlock,
+        layers=[3, 4, 14, 3],
+        **kwargs,
+    )
+
+
+def iresnet100(input_shape, **kwargs):
+    return IResNet(
+        name="iresnet100",
+        input_shape=input_shape,
+        block=IBasicBlock,
+        layers=[3, 13, 30, 3],
+        **kwargs,
+    )
+
+
+def iresnet200(input_shape, **kwargs):
+    return IResNet(
+        name="iresnet200",
+        input_shape=input_shape,
+        block=IBasicBlock,
+        layers=[6, 26, 60, 6],
+        **kwargs,
+    )
+
+
+if __name__ == "__main__":
+    model = iresnet50((112, 112, 3))
+    model.summary()
+    tf.keras.utils.plot_model(
+        model, "keras_model.svg", show_shapes=True, expand_nested=True, dpi=300
+    )
diff --git a/bob/learn/tensorflow/scripts/fit.py b/bob/learn/tensorflow/scripts/fit.py
index bb3085924b3d9c7f5ff768c06f22d03f6d3b505a..0c6ee51045e66bd6172ba9163a6687e0869a48a4 100644
--- a/bob/learn/tensorflow/scripts/fit.py
+++ b/bob/learn/tensorflow/scripts/fit.py
@@ -1,6 +1,5 @@
 #!/usr/bin/env python
-"""Trains networks using Keras Models.
-"""
+"""Trains networks using Keras Models."""
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
@@ -78,6 +77,27 @@ logger = logging.getLogger(__name__)
     cls=ResourceOption,
     help="See tf.keras.Model.fit.",
 )
+@click.option(
+    "--dask-client",
+    "-l",
+    entry_point_group="dask.client",
+    default=None,
+    help="Dask client for the execution of the pipeline.",
+    cls=ResourceOption,
+)
+@click.option(
+    "--strategy-fn",
+    entry_point_group="bob.learn.tensorflow.strategy",
+    default=None,
+    help="The strategy to be used for distributed training.",
+    cls=ResourceOption,
+)
+@click.option(
+    "--mixed-precision-policy",
+    default=None,
+    help="The mixed precision policy to be used for training.",
+    cls=ResourceOption,
+)
 @verbosity_option(cls=ResourceOption)
 def fit(
     model_fn,
@@ -89,9 +109,20 @@ def fit(
     class_weight,
     steps_per_epoch,
     validation_steps,
-    **kwargs
+    dask_client,
+    strategy_fn,
+    mixed_precision_policy,
+    **kwargs,
 ):
     """Trains networks using Keras models."""
+    from tensorflow.keras import mixed_precision
+
+    from bob.extension.log import set_verbosity_level
+    from bob.extension.log import setup as setup_logger
+
+    from ..utils import FloatValuesEncoder
+    from ..utils import compute_tf_config_from_dask_client
+
     log_parameters(logger)
 
     # Train
@@ -102,19 +133,92 @@ def fit(
     if save_callback:
         model_dir = save_callback[0].filepath
         logger.info("Training a model in %s", model_dir)
-    model = model_fn()
-
-    history = model.fit(
-        x=train_input_fn(),
-        epochs=epochs,
-        verbose=max(verbose, 2),
-        callbacks=list(callbacks) if callbacks else None,
-        validation_data=None if eval_input_fn is None else eval_input_fn(),
-        class_weight=class_weight,
-        steps_per_epoch=steps_per_epoch,
-        validation_steps=validation_steps,
-    )
-    click.echo(history.history)
-    if model_dir is not None:
-        with open(os.path.join(model_dir, "keras_fit_history.json"), "w") as f:
-            json.dump(history.history, f)
+    callbacks = list(callbacks) if callbacks else None
+
+    def train(tf_config=None):
+        # setup verbosity again in case we're in a dask worker
+        setup_logger("bob")
+        set_verbosity_level("bob", verbose)
+
+        if tf_config is not None:
+            logger.info("Setting up TF_CONFIG with %s", tf_config)
+            os.environ["TF_CONFIG"] = json.dumps(tf_config)
+
+        if mixed_precision_policy is not None:
+            logger.info("Using %s mixed precision policy", mixed_precision_policy)
+            mixed_precision.set_global_policy(mixed_precision_policy)
+
+        validation_data = None
+
+        if strategy_fn is None:
+            model: tf.keras.Model = model_fn()
+            x = train_input_fn()
+            if eval_input_fn is not None:
+                validation_data = eval_input_fn()
+        else:
+            strategy = strategy_fn()
+            with strategy.scope():
+                model: tf.keras.Model = model_fn()
+                x = strategy.distribute_datasets_from_function(train_input_fn)
+                if eval_input_fn is not None:
+                    validation_data = strategy.distribute_datasets_from_function(
+                        eval_input_fn
+                    )
+
+        # swap verbosity levels 1 and 2 for Keras, since model.fit is more verbose with verbose=1 than with verbose=2
+        fit_verbose = {0: 0, 1: 2, 2: 1}[min(verbose, 2)]
+
+        click.echo(
+            f"""Calling {model}.fit with:(
+            x={x},
+            epochs={epochs},
+            verbose={fit_verbose},
+            callbacks={callbacks},
+            validation_data={validation_data},
+            class_weight={class_weight},
+            steps_per_epoch={steps_per_epoch},
+            validation_steps={validation_steps},
+        )
+        and optimizer: {model.optimizer}
+        """
+        )
+        history = model.fit(
+            x=x,
+            epochs=epochs,
+            verbose=fit_verbose,
+            callbacks=callbacks,
+            validation_data=validation_data,
+            class_weight=class_weight,
+            steps_per_epoch=steps_per_epoch,
+            validation_steps=validation_steps,
+        )
+        if model_dir is not None:
+            with open(os.path.join(model_dir, "keras_fit_history.json"), "w") as f:
+                json.dump(history.history, f, cls=FloatValuesEncoder)
+
+        return history.history
+
+    if dask_client is None:
+        history = train()
+    else:
+        tf_configs, workers_ips = compute_tf_config_from_dask_client(dask_client)
+        future_histories = []
+        for tf_spec, ip in zip(tf_configs, workers_ips):
+            future = dask_client.submit(train, tf_spec, workers=ip)
+            future_histories.append(future)
+
+        try:
+            history = dask_client.gather(future_histories)
+        finally:
+            try:
+                logger.debug("Printing dask logs:")
+                for key, value in dask_client.cluster.get_logs().items():
+                    logger.debug(f"{key}:")
+                    logger.debug(value)
+                logger.debug(dask_client.cluster.job_script())
+            except Exception:
+                pass
+
+    logger.debug("history:")
+    logger.debug(history)
+    return history
diff --git a/bob/learn/tensorflow/tests/test_arcface.py b/bob/learn/tensorflow/tests/test_arcface.py
index 4ebfe63e29677cddf54251949eef90beaab6712d..50cef59ac280c84e3b67583c75948b4ede04e16d 100644
--- a/bob/learn/tensorflow/tests/test_arcface.py
+++ b/bob/learn/tensorflow/tests/test_arcface.py
@@ -8,12 +8,13 @@ from bob.learn.tensorflow.models import ArcFaceLayer3Penalties
 
 def test_arcface_layer():
 
-    layer = ArcFaceLayer()
+    layer = ArcFaceLayer(n_classes=10)
     np.random.seed(10)
     X = np.random.rand(10, 50)
     y = [np.random.randint(10) for i in range(10)]
+    output = layer(X, y)
 
-    assert layer(X, y).shape == (10, 10)
+    assert output.shape == (10, 10), output.shape
 
 
 def test_arcface_layer_3p():
diff --git a/bob/learn/tensorflow/utils/keras.py b/bob/learn/tensorflow/utils/keras.py
index 8447d112e8ddbbe529db357a9e449878ae9b64d1..5150e209545598b6b7b9783e4cce22e8b23a2792 100644
--- a/bob/learn/tensorflow/utils/keras.py
+++ b/bob/learn/tensorflow/utils/keras.py
@@ -1,5 +1,10 @@
+import json
 import logging
+import os
+import re
+from json import JSONEncoder
 
+import numpy as np
 import tensorflow as tf
 import tensorflow.keras.backend as K
 from tensorflow.python.util import nest
@@ -13,6 +18,60 @@ SINGLE_LAYER_OUTPUT_ERROR_MSG = (
 )
 
 
+class FloatValuesEncoder(JSONEncoder):
+    """Code from https://stackoverflow.com/a/64155446"""
+
+    def default(self, obj):
+        if isinstance(obj, (np.float16, np.float32, np.float64)):
+            return float(obj)
+        return super().default(obj)
+
+
+def compute_tf_config_from_dask_client(client, reference_tf_port=2222):
+    """
+    Computes the TensorFlow ``TF_CONFIG`` of each worker from a Dask client.
+
+    See the multi-worker training tutorial for more information on this
+    configuration:
+
+    https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras#multi-worker_configuration
+
+    Parameters
+    ----------
+    client:
+        Dask client
+
+    reference_tf_port:
+        Port assigned to the first TensorFlow worker; each subsequent
+        worker gets the next port number.
+
+    Returns
+    -------
+    tf_configs : list
+        A list of tf configs. Each tf config will be for one worker.
+    workers_ips : list
+        The IP addresses of the workers, in the same order as ``tf_configs``.
+    """
+
+    worker_addresses = sorted(client.scheduler_info()["workers"].keys())
+
+    port = reference_tf_port
+    tf_clients, workers_ips = [], []
+    for address in worker_addresses:
+        # replace the dask worker port with a dedicated TF port and strip the
+        # protocol, e.g. "tcp://10.0.0.1:40123" -> "10.0.0.1:2222"
+        index = re.search("[0-9]:[0-9]", address)
+        host = address[0 : index.start() + 1] + f":{port}"
+        host = host.split("://")[-1]
+
+        tf_clients.append(host)
+        workers_ips.append(host.split(":")[0])
+        port += 1
+
+    # cluster config
+    cluster = {"worker": tf_clients}
+
+    tf_configs = []
+    for i, _ in enumerate(tf_clients):
+        tf_configs.append({"cluster": cluster, "task": {"type": "worker", "index": i}})
+
+    return tf_configs, workers_ips
+
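+# A minimal usage sketch (assuming a connected ``dask.distributed.Client``):
+#
+#   tf_configs, workers_ips = compute_tf_config_from_dask_client(client)
+#   # each worker i then sets:
+#   # os.environ["TF_CONFIG"] = json.dumps(tf_configs[i])
+#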
+
 def keras_channels_index():
     return -3 if K.image_data_format() == "channels_first" else -1
 
@@ -67,6 +126,25 @@ def initialize_model_from_checkpoint(model, checkpoint, normalizer=None):
     tf.compat.v1.train.init_from_checkpoint(checkpoint, assignment_map=assignment_map)
 
 
+def get_number_of_workers():
+    """Returns the number of workers in a distributed strategy.
+    Can be used to increase the batch size dynamically in distributed training.
+
+    Returns
+    -------
+    int
+        The number of workers present in a strategy.
+    """
+    num_workers = 1
+
+    tf_config = os.environ.get("TF_CONFIG")
+    if tf_config is not None:
+        tf_config = json.loads(tf_config)
+        num_workers = len(tf_config["cluster"]["worker"])
+
+    return num_workers
+
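+# A minimal sketch of the dynamic batch size use case from the docstring
+# (the per-worker batch size of 64 is an arbitrary assumption):
+#
+#   global_batch_size = 64 * get_number_of_workers()
+#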
+
 def model_summary(model, do_print=False):
     from tensorflow.keras.backend import count_params
 
diff --git a/doc/conf.py b/doc/conf.py
index a7a68a78b4ccf2560fe7a3cf2a10aa91cf40808e..ef68da351a5b41707bc383e998d8963abcd7e5f8 100644
--- a/doc/conf.py
+++ b/doc/conf.py
@@ -34,7 +34,7 @@ extensions = [
 ]
 
 # Be picky about warnings
-nitpicky = True
+nitpicky = False
 
 # Ignores stuff we can't easily resolve on other project's sphinx manuals
 nitpick_ignore = []
@@ -233,6 +233,7 @@ autodoc_default_options = {
     "undoc-members": True,
     "show-inheritance": True,
 }
+autodoc_inherit_docstrings = False
 
 
 sphinx_requirements = "extra-intersphinx.txt"
diff --git a/doc/user_guide.rst b/doc/user_guide.rst
index b36953da60e98298d01e9a280d28a565a90bd998..ad700c2848870f9421574bf62bdd06df6ffb474f 100644
--- a/doc/user_guide.rst
+++ b/doc/user_guide.rst
@@ -123,6 +123,13 @@ It is important that custom metrics and losses do not average their results by t
 size as the values should be averaged by the global batch size:
 https://www.tensorflow.org/tutorials/distribute/custom_training Take a look at custom
 metrics and losses in this package for examples of correct implementations.
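+For example, a custom loss should divide per-example values by the global batch
+size instead of taking a per-replica mean. A minimal sketch of such a loss (this
+``scaled_mse`` function is illustrative and not part of this package)::
+
+    import tensorflow as tf
+
+    def scaled_mse(labels, predictions, global_batch_size):
+        # per-example loss: no reduction over the batch dimension yet
+        per_example_loss = tf.reduce_sum(tf.square(labels - predictions), axis=-1)
+        # average by the global batch size, not the per-replica one
+        return tf.nn.compute_average_loss(
+            per_example_loss, global_batch_size=global_batch_size
+        )
+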
+It is best not to override ``train_step`` and ``test_step`` in your model, so that
+you do not have to deal with the details of distributed training yourself.
 
+Also, see the distributed training example in this package's repository:
+``examples/mnist_multi_worker_mixed_precision.py``, which uses Dask. It can be
+executed using::
+
+    bob keras fit -vvv mnist_multi_worker_mixed_precision.py
 
 .. _tensorflow: https://www.tensorflow.org/
diff --git a/examples/mnist_multi_worker_mixed_precision.py b/examples/mnist_multi_worker_mixed_precision.py
new file mode 100644
index 0000000000000000000000000000000000000000..99aea639f0dc58ebb4a5f7cbdd525058ec062d00
--- /dev/null
+++ b/examples/mnist_multi_worker_mixed_precision.py
@@ -0,0 +1,126 @@
+import sys
+
+import dask
+import numpy as np
+import tensorflow as tf
+from dask.distributed import Client
+from dask_jobqueue import SGECluster
+
+from bob.extension import rc
+from bob.learn.tensorflow.callbacks import add_backup_callback
+
+mixed_precision_policy = "mixed_float16"
+strategy_fn = "multi-worker-mirrored-strategy"
+
+
+N_WORKERS = 2
+BATCH_SIZE = 64 * N_WORKERS
+checkpoint_path = "mnist_distributed_mixed_precision"
+steps_per_epoch = 60000 // BATCH_SIZE
+epochs = 2
+
+
+def train_input_fn(ctx=None):
+    (x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
+    x_train = x_train / np.float32(255)
+    y_train = y_train.astype(np.int64)
+    train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
+
+    batch_size = BATCH_SIZE
+    if ctx is not None:
+        # shard the dataset BEFORE any shuffling
+        train_dataset = train_dataset.shard(
+            ctx.num_input_pipelines, ctx.input_pipeline_id
+        )
+        # compute the per-replica batch size from the global one
+        batch_size = ctx.get_per_replica_batch_size(BATCH_SIZE)
+
+    # create an infinite dataset, `.repeat()`, as needed for distributed training
+    train_dataset = train_dataset.shuffle(60000).repeat().batch(batch_size)
+    return train_dataset
+
+
+def model_fn():
+    model = tf.keras.Sequential(
+        [
+            tf.keras.Input(shape=(28, 28)),
+            tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
+            tf.keras.layers.Conv2D(32, 3, activation="relu"),
+            tf.keras.layers.Flatten(),
+            tf.keras.layers.Dense(128, activation="relu"),
+            tf.keras.layers.Dense(10),
+            # to support mixed precision training, output(s) must be float32
+            tf.keras.layers.Activation("linear", dtype="float32"),
+        ]
+    )
+    model.compile(
+        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+        optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),
+        metrics=["accuracy"],
+    )
+    return model
+
+
+# dask.config.set({"distributed.comm.timeouts.connect": "30s"})
+dask.config.set({"jobqueue.sge.walltime": None})
+dask.config.set({"distributed.worker.memory.target": False})  # Avoid spilling to disk
+dask.config.set({"distributed.worker.memory.spill": False})  # Avoid spilling to disk
+
+cluster = SGECluster(
+    queue="q_short_gpu",
+    memory="28GB",
+    cores=1,
+    processes=1,
+    log_directory="./logs",
+    silence_logs="debug",
+    resource_spec="q_short_gpu=TRUE,hostname=vgne*",
+    project=rc.get("sge.project"),
+    env_extra=[
+        "export PYTHONUNBUFFERED=1",
+        f"export PYTHONPATH={':'.join(sys.path)}",
+        #
+        # may need to unset proxies (probably set by SGE) to make sure tensorflow workers can communicate
+        # see: https://stackoverflow.com/a/66059809/1286165
+        # "unset http_proxy https_proxy HTTP_PROXY HTTPS_PROXY",
+        #
+        # May need to tell dask workers not to use daemonic processes
+        # see: https://github.com/dask/distributed/issues/2718
+        # "export DASK_DISTRIBUTED__WORKER__DAEMON=False",
+        #
+        # f"export LD_LIBRARY_PATH={os.environ.get('LD_LIBRARY_PATH', '')}",
+    ],
+)
+cluster.scale(N_WORKERS)
+dask_client = Client(cluster, timeout="2m")
+print(f"Waiting (max 2 hours) for {N_WORKERS} dask workers to come online ...")
+dask_client.wait_for_workers(n_workers=N_WORKERS, timeout="2h")
+print(f"All requested {N_WORKERS} dask workers are ready!")
+
+
+def scheduler(epoch, lr):
+    if epoch < 20:
+        return 0.1
+    elif epoch < 30:
+        return 0.01
+    else:
+        return 0.001
+
+
+callbacks = {
+    "latest": tf.keras.callbacks.ModelCheckpoint(
+        f"{checkpoint_path}/latest", verbose=1
+    ),
+    "best": tf.keras.callbacks.ModelCheckpoint(
+        f"{checkpoint_path}/best",
+        save_best_only=True,
+        monitor="accuracy",
+        mode="max",
+        verbose=1,
+    ),
+    "tensorboard": tf.keras.callbacks.TensorBoard(
+        log_dir=f"{checkpoint_path}/logs", update_freq=15, profile_batch=0
+    ),
+    "lr": tf.keras.callbacks.LearningRateScheduler(scheduler, verbose=1),
+    "nan": tf.keras.callbacks.TerminateOnNaN(),
+}
+callbacks = add_backup_callback(callbacks, backup_dir=f"{checkpoint_path}/backup")
diff --git a/setup.py b/setup.py
index 8e00780b0c80252e84e0785b7142d72c6bb7b04e..ba7bce8a61a9d3cc860ec22eb7ff928a021fd435 100644
--- a/setup.py
+++ b/setup.py
@@ -54,6 +54,11 @@ setup(
         "bob.learn.tensorflow.keras_cli": [
             "fit = bob.learn.tensorflow.scripts.fit:fit",
         ],
+        # entry points for bob keras fit --strategy-fn option
+        "bob.learn.tensorflow.strategy": [
+            "multi-worker-mirrored-strategy = bob.learn.tensorflow.configs.MultiWorkerMirroredStrategy:strategy_fn",
+            "mirrored-strategy = bob.learn.tensorflow.configs.MirroredStrategy:strategy_fn",
+        ],
     },
     # Classifiers are important if you plan to distribute this package through
     # PyPI. You can find the complete list of classifiers that are valid and