From 8f1b8d379c8824cd9bfc966cb7fcf3fde19496d7 Mon Sep 17 00:00:00 2001
From: Amir MOHAMMADI <amir.mohammadi@idiap.ch>
Date: Wed, 30 Jun 2021 14:55:41 +0200
Subject: [PATCH] Add a mirrored strategy

---
 bob/learn/tensorflow/configs/MirroredStrategy.py | 9 +++++++++
 setup.py                                         | 1 +
 2 files changed, 10 insertions(+)
 create mode 100644 bob/learn/tensorflow/configs/MirroredStrategy.py

diff --git a/bob/learn/tensorflow/configs/MirroredStrategy.py b/bob/learn/tensorflow/configs/MirroredStrategy.py
new file mode 100644
index 00000000..5273ac34
--- /dev/null
+++ b/bob/learn/tensorflow/configs/MirroredStrategy.py
@@ -0,0 +1,9 @@
+import tensorflow as tf
+
+
+def strategy_fn():
+    print("Creating MirroredStrategy strategy.")
+    strategy = tf.distribute.MirroredStrategy()
+    print("MirroredStrategy strategy created.")
+    print("Number of devices: {}".format(strategy.num_replicas_in_sync))
+    return strategy
diff --git a/setup.py b/setup.py
index 8cc047c8..ba7bce8a 100644
--- a/setup.py
+++ b/setup.py
@@ -57,6 +57,7 @@ setup(
         # entry points for bob keras fit --strategy-fn option
         "bob.learn.tensorflow.strategy": [
             "multi-worker-mirrored-strategy = bob.learn.tensorflow.configs.MultiWorkerMirroredStrategy:strategy_fn",
+            "mirrored-strategy = bob.learn.tensorflow.configs.MirroredStrategy:strategy_fn",
         ],
     },
     # Classifiers are important if you plan to distribute this package through
-- 
GitLab