From 8f1b8d379c8824cd9bfc966cb7fcf3fde19496d7 Mon Sep 17 00:00:00 2001 From: Amir MOHAMMADI <amir.mohammadi@idiap.ch> Date: Wed, 30 Jun 2021 14:55:41 +0200 Subject: [PATCH] Add a mirrored strategy --- bob/learn/tensorflow/configs/MirroredStrategy.py | 9 +++++++++ setup.py | 1 + 2 files changed, 10 insertions(+) create mode 100644 bob/learn/tensorflow/configs/MirroredStrategy.py diff --git a/bob/learn/tensorflow/configs/MirroredStrategy.py b/bob/learn/tensorflow/configs/MirroredStrategy.py new file mode 100644 index 00000000..5273ac34 --- /dev/null +++ b/bob/learn/tensorflow/configs/MirroredStrategy.py @@ -0,0 +1,9 @@ +import tensorflow as tf + + +def strategy_fn(): + print("Creating MirroredStrategy strategy.") + strategy = tf.distribute.MirroredStrategy() + print("MirroredStrategy strategy created.") + print("Number of devices: {}".format(strategy.num_replicas_in_sync)) + return strategy diff --git a/setup.py b/setup.py index 8cc047c8..ba7bce8a 100644 --- a/setup.py +++ b/setup.py @@ -57,6 +57,7 @@ setup( # entry points for bob keras fit --strategy-fn option "bob.learn.tensorflow.strategy": [ "multi-worker-mirrored-strategy = bob.learn.tensorflow.configs.MultiWorkerMirroredStrategy:strategy_fn", + "mirrored-strategy = bob.learn.tensorflow.configs.MirroredStrategy:strategy_fn", ], }, # Classifiers are important if you plan to distribute this package through -- GitLab