diff --git a/bob/learn/tensorflow/configs/MirroredStrategy.py b/bob/learn/tensorflow/configs/MirroredStrategy.py new file mode 100644 index 0000000000000000000000000000000000000000..5273ac3465a70d7c7b7371d7637b19c4428ec536 --- /dev/null +++ b/bob/learn/tensorflow/configs/MirroredStrategy.py @@ -0,0 +1,9 @@ +import tensorflow as tf + + +def strategy_fn(): + print("Creating MirroredStrategy strategy.") + strategy = tf.distribute.MirroredStrategy() + print("MirroredStrategy strategy created.") + print("Number of devices: {}".format(strategy.num_replicas_in_sync)) + return strategy diff --git a/setup.py b/setup.py index 8cc047c82bbdb50fba8e2cf1ee37c77368c205e8..ba7bce8a61a9d3cc860ec22eb7ff928a021fd435 100644 --- a/setup.py +++ b/setup.py @@ -57,6 +57,7 @@ setup( # entry points for bob keras fit --strategy-fn option "bob.learn.tensorflow.strategy": [ "multi-worker-mirrored-strategy = bob.learn.tensorflow.configs.MultiWorkerMirroredStrategy:strategy_fn", + "mirrored-strategy = bob.learn.tensorflow.configs.MirroredStrategy:strategy_fn", ], }, # Classifiers are important if you plan to distribute this package through