Commit dfa48e09 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI

Add distributed training support with dask

parent d2ca498b
Pipeline #51543 failed with stage in 4 minutes and 24 seconds
import logging
import tensorflow as tf
logger = logging.getLogger(__name__)
def strategy_fn():
print("creating strategy")
strategy = tf.distribute.MultiWorkerMirroredStrategy()
print("strategy created")
return strategy
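# Note: this module is exposed as the "multi-worker-mirrored-strategy" entry point
# (see setup.py) and can be selected with:
#   bob keras fit --strategy-fn multi-worker-mirrored-strategy ...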
#!/usr/bin/env python
"""Trains networks using Keras Models.
"""
"""Trains networks using Keras Models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@@ -78,6 +77,27 @@ logger = logging.getLogger(__name__)
cls=ResourceOption,
help="See tf.keras.Model.fit.",
)
@click.option(
"--dask-client",
"-l",
entry_point_group="dask.client",
default=None,
help="Dask client for the execution of the pipeline.",
cls=ResourceOption,
)
@click.option(
"--strategy-fn",
entry_point_group="bob.learn.tensorflow.strategy",
default=None,
help="The strategy to be used for distributed training.",
cls=ResourceOption,
)
@click.option(
"--mixed-precision-policy",
default=None,
help="The mixed precision policy to be used for training.",
cls=ResourceOption,
)
@verbosity_option(cls=ResourceOption)
def fit(
model_fn,
@@ -89,9 +109,20 @@ def fit(
class_weight,
steps_per_epoch,
validation_steps,
dask_client,
strategy_fn,
mixed_precision_policy,
**kwargs,
):
"""Trains networks using Keras models."""
from tensorflow.keras import mixed_precision
from bob.extension.log import set_verbosity_level
from bob.extension.log import setup as setup_logger
from ..utils import FloatValuesEncoder
from ..utils import compute_tf_config_from_dask_client
log_parameters(logger)
# Train
@@ -102,19 +133,91 @@ def fit(
if save_callback:
model_dir = save_callback[0].filepath
logger.info("Training a model in %s", model_dir)
model = model_fn()
history = model.fit(
x=train_input_fn(),
epochs=epochs,
verbose=max(verbose, 2),
callbacks=list(callbacks) if callbacks else None,
validation_data=None if eval_input_fn is None else eval_input_fn(),
class_weight=class_weight,
steps_per_epoch=steps_per_epoch,
validation_steps=validation_steps,
)
click.echo(history.history)
if model_dir is not None:
with open(os.path.join(model_dir, "keras_fit_history.json"), "w") as f:
json.dump(history.history, f)
callbacks = list(callbacks) if callbacks else None
def train(tf_config=None):
# setup verbosity again in case we're in a dask worker
setup_logger("bob")
set_verbosity_level("bob", verbose)
if tf_config is not None:
logger.debug("Setting up TF_CONFIG with %s", tf_config)
os.environ["TF_CONFIG"] = json.dumps(tf_config)
if mixed_precision_policy is not None:
mixed_precision.set_global_policy(mixed_precision_policy)
validation_data = None
if strategy_fn is None:
model = model_fn()
x = train_input_fn()
if eval_input_fn is not None:
validation_data = eval_input_fn()
else:
strategy = strategy_fn()
with strategy.scope():
model = model_fn()
x = strategy.distribute_datasets_from_function(train_input_fn)
if eval_input_fn is not None:
validation_data = strategy.distribute_datasets_from_function(
eval_input_fn
)
# swap the 1 and 2 verbosity values for Keras, since verbose=1 is the more verbose option in model.fit
fit_verbose = {0: 0, 1: 2, 2: 1}[min(verbose, 2)]
click.echo(
f"""Calling {model}.fit with:(
x={x},
epochs={epochs},
verbose={fit_verbose},
callbacks={callbacks},
validation_data={validation_data},
class_weight={class_weight},
steps_per_epoch={steps_per_epoch},
validation_steps={validation_steps},
)
and optimizer: {model.optimizer}
"""
)
history = model.fit(
x=x,
epochs=epochs,
verbose=fit_verbose,
callbacks=callbacks,
validation_data=validation_data,
class_weight=class_weight,
steps_per_epoch=steps_per_epoch,
validation_steps=validation_steps,
)
if model_dir is not None:
with open(os.path.join(model_dir, "keras_fit_history.json"), "w") as f:
json.dump(history.history, f, cls=FloatValuesEncoder)
return history.history
if dask_client is None:
history = train()
else:
tf_configs, workers_ips = compute_tf_config_from_dask_client(dask_client)
future_histories = []
for tf_spec, ip in zip(tf_configs, workers_ips):
future = dask_client.submit(train, tf_spec, workers=ip)
future_histories.append(future)
try:
history = dask_client.gather(future_histories)
finally:
try:
logger.debug("Printing dask logs:")
for key, value in dask_client.cluster.get_logs().items():
logger.debug(f"{key}:")
logger.debug(value)
logger.debug(dask_client.cluster.job_script())
except Exception:
pass
logger.debug("history:")
logger.debug(history)
return history
import json
import logging
import os
import re
from json import JSONEncoder
import numpy as np
import tensorflow as tf
import tensorflow.keras.backend as K
from tensorflow.python.util import nest
@@ -13,6 +18,60 @@ SINGLE_LAYER_OUTPUT_ERROR_MSG = (
)
class FloatValuesEncoder(JSONEncoder):
"""Code from https://stackoverflow.com/a/64155446"""
def default(self, obj):
if isinstance(obj, (np.float16, np.float32, np.float64)):
return float(obj)
return super().default(obj)
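# Example usage (as done in the fit script): serialize a Keras history dict that
# may contain numpy float values:
#   json.dump(history.history, f, cls=FloatValuesEncoder)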
def compute_tf_config_from_dask_client(client, reference_tf_port=2222):
"""
Computes the TensorFlow TF_CONFIG (one per worker) from a dask client.
See the TensorFlow documentation for more information on this configuration:
https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras#multi-worker_configuration
Parameters
----------
client:
Dask client
reference_tf_port:
The starting port for the TensorFlow workers; each subsequent worker uses the next port.
Returns
-------
tf_configs : list
A list of TF_CONFIG dicts, one per worker.
workers_ips : list
The IP addresses of the corresponding dask workers, in the same order.
"""
clients = list(sorted(client.scheduler_info()["workers"].keys()))
port = reference_tf_port
tf_clients, workers_ips = [], []
for worker_address in clients:
index = re.search("[0-9]:[0-9]", worker_address)
host = worker_address[0 : index.start() + 1] + f":{port}"
host = host.split("://")[-1]
tf_clients.append(host)
workers_ips.append(host.split(":")[0])
port += 1
# cluster config
cluster = {"worker": tf_clients}
tf_configs = []
for i, _ in enumerate(tf_clients):
tf_configs.append({"cluster": cluster, "task": {"type": "worker", "index": i}})
return tf_configs, workers_ips
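# Example (sketch, with hypothetical worker addresses): for a dask client with two
# workers at 10.0.0.1 and 10.0.0.2 and the default reference_tf_port of 2222, the
# return values would look roughly like:
#   tf_configs = [
#       {"cluster": {"worker": ["10.0.0.1:2222", "10.0.0.2:2223"]},
#        "task": {"type": "worker", "index": 0}},
#       {"cluster": {"worker": ["10.0.0.1:2222", "10.0.0.2:2223"]},
#        "task": {"type": "worker", "index": 1}},
#   ]
#   workers_ips = ["10.0.0.1", "10.0.0.2"]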
def keras_channels_index():
return -3 if K.image_data_format() == "channels_first" else -1
@@ -67,6 +126,25 @@ def initialize_model_from_checkpoint(model, checkpoint, normalizer=None):
tf.compat.v1.train.init_from_checkpoint(checkpoint, assignment_map=assignment_map)
def get_number_of_workers():
"""Returns the number of workers in a distributed strategy.
Can be used to increase the batch size dynamically in distributed training.
Returns
-------
int
The number of workers present in a strategy.
"""
num_workers = 1
tf_config = os.environ.get("TF_CONFIG")
if tf_config is not None:
tf_config = json.loads(tf_config)
num_workers = len(tf_config["cluster"]["worker"])
return num_workers
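# Example usage (sketch): scale the global batch size by the number of workers,
# e.g. when building the training dataset (64 here is an arbitrary per-worker batch size):
#   global_batch_size = 64 * get_number_of_workers()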
def model_summary(model, do_print=False):
from tensorflow.keras.backend import count_params
@@ -123,6 +123,13 @@ It is important that custom metrics and losses do not average their results by the batch
size as the values should be averaged by the global batch size:
https://www.tensorflow.org/tutorials/distribute/custom_training Take a look at custom
metrics and losses in this package for examples of correct implementations.
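For instance, a minimal sketch of such a loss (the name ``sparse_softmax_loss`` is only
illustrative) that averages by the global batch size when running under a distribution
strategy could look like::

    import tensorflow as tf

    def sparse_softmax_loss(y_true, y_pred):
        # per-example loss values, no reduction yet
        per_example_loss = tf.keras.losses.sparse_categorical_crossentropy(
            y_true, y_pred, from_logits=True
        )
        # average by the GLOBAL batch size, not the per-replica batch size
        return tf.nn.compute_average_loss(per_example_loss)

``tf.nn.compute_average_loss`` divides the summed per-example losses by the global batch
size for you.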
It is best not to override ``train_step`` and ``test_step`` in your model, so that you
do not have to deal with the details of distributed training yourself.
Also, see the distributed training example (which uses dask) in this package's repository
at ``examples/mnist_multi_worker_mixed_precision.py``. It can be
executed using::
bob keras fit -vvv mnist_multi_worker_mixed_precision.py
.. _tensorflow: https://www.tensorflow.org/
import sys
import dask
import numpy as np
import tensorflow as tf
from dask.distributed import Client
from dask_jobqueue import SGECluster
from bob.extension import rc
from bob.learn.tensorflow.callbacks import add_backup_callback
mixed_precision_policy = "mixed_float16"
strategy_fn = "multi-worker-mirrored-strategy"
N_WORKERS = 2
BATCH_SIZE = 64 * N_WORKERS
checkpoint_path = "mnist_distributed_mixed_precision"
steps_per_epoch = 60000 // BATCH_SIZE
epochs = 2
def train_input_fn(ctx=None):
(x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
x_train = x_train / np.float32(255)
y_train = y_train.astype(np.int64)
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
batch_size = BATCH_SIZE
if ctx is not None:
# shard the dataset BEFORE any shuffling
train_dataset = train_dataset.shard(
ctx.num_replicas_in_sync, ctx.input_pipeline_id
)
# calculate batch size per worker
batch_size = ctx.get_per_replica_batch_size(BATCH_SIZE)
# create infinite datasets, `.repeat()`, for distributed training
train_dataset = train_dataset.shuffle(60000).repeat().batch(batch_size)
return train_dataset
def model_fn():
model = tf.keras.Sequential(
[
tf.keras.Input(shape=(28, 28)),
tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
tf.keras.layers.Conv2D(32, 3, activation="relu"),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(128, activation="relu"),
tf.keras.layers.Dense(10),
# to support mixed precision training, output(s) must be float32
tf.keras.layers.Activation("linear", dtype="float32"),
]
)
model.compile(
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),
metrics=["accuracy"],
)
return model
# dask.config.set({"distributed.comm.timeouts.connect": "30s"})
dask.config.set({"jobqueue.sge.walltime": None})
dask.config.set({"distributed.worker.memory.target": False}) # Avoid spilling to disk
dask.config.set({"distributed.worker.memory.spill": False}) # Avoid spilling to disk
cluster = SGECluster(
queue="q_short_gpu",
memory="28GB",
cores=1,
processes=1,
log_directory="./logs",
silence_logs="debug",
resource_spec="q_short_gpu=TRUE,hostname=vgne*",
project=rc.get("sge.project"),
env_extra=[
"export PYTHONUNBUFFERED=1",
f"export PYTHONPATH={':'.join(sys.path)}",
#
# may need to unset proxies (probably set by SGE) to make sure tensorflow workers can communicate
# see: https://stackoverflow.com/a/66059809/1286165
# "unset http_proxy https_proxy HTTP_PROXY HTTPS_PROXY",
#
# May need to tell dask workers not to use daemonic processes
# see: https://github.com/dask/distributed/issues/2718
# "export DASK_DISTRIBUTED__WORKER__DAEMON=False",
#
# f"export LD_LIBRARY_PATH={os.environ.get('LD_LIBRARY_PATH', '')}",
],
)
cluster.scale(N_WORKERS)
dask_client = Client(cluster, timeout="2m")
print(f"Waiting (max 2 hours) for {N_WORKERS} dask workers to come online ...")
dask_client.wait_for_workers(n_workers=N_WORKERS, timeout="2h")
print(f"All requested {N_WORKERS} dask workers are ready!")
def scheduler(epoch, lr):
if epoch in range(20):
return 0.1
elif epoch in range(20, 30):
return 0.01
else:
return 0.001
callbacks = {
"latest": tf.keras.callbacks.ModelCheckpoint(
f"{checkpoint_path}/latest", verbose=1
),
"best": tf.keras.callbacks.ModelCheckpoint(
f"{checkpoint_path}/best",
save_best_only=True,
monitor="accuracy",
mode="max",
verbose=1,
),
"tensorboard": tf.keras.callbacks.TensorBoard(
log_dir=f"{checkpoint_path}/logs", update_freq=15, profile_batch=0
),
"lr": tf.keras.callbacks.LearningRateScheduler(scheduler, verbose=1),
"nan": tf.keras.callbacks.TerminateOnNaN(),
}
callbacks = add_backup_callback(callbacks, backup_dir=f"{checkpoint_path}/backup")
@@ -54,6 +54,10 @@ setup(
"bob.learn.tensorflow.keras_cli": [
"fit = bob.learn.tensorflow.scripts.fit:fit",
],
# entry points for bob keras fit --strategy-fn option
"bob.learn.tensorflow.strategy": [
"multi-worker-mirrored-strategy = bob.learn.tensorflow.configs.MultiWorkerMirroredStrategy:strategy_fn",
],
},
# Classifiers are important if you plan to distribute this package through
# PyPI. You can find the complete list of classifiers that are valid and