diff --git a/bob/pipelines/distributed/__init__.py b/bob/pipelines/distributed/__init__.py
index 33a51fa6be31be5ac5eaf64a36dd95087f553216..77c779ea84ec25f8576d68b4832c7fb4549b5936 100644
--- a/bob/pipelines/distributed/__init__.py
+++ b/bob/pipelines/distributed/__init__.py
@@ -16,3 +16,38 @@ __path__ = extend_path(__path__, __name__)
 # )
 
 VALID_DASK_CLIENT_STRINGS = ("single-threaded", "sync", "threaded", "processes")
+
+
+def dask_get_partition_size(cluster, n_objects, lower_bound=200):
+    """
+    Heuristic that suggests a value for dask's ``partition_size``.
+    The heuristic is simple: given the maximum number of workers that can run in a queue
+    (not the number of workers currently running) and the total number of objects to be processed, compute ``n_objects / max_jobs``.
+
+    Check https://docs.dask.org/en/latest/best-practices.html#avoid-very-large-partitions
+    for best practices.
+
+    Parameters
+    ----------
+
+    cluster: :any:`bob.pipelines.distributed.sge.SGEMultipleQueuesCluster`
+        Cluster of the type :any:`bob.pipelines.distributed.sge.SGEMultipleQueuesCluster`
+
+    n_objects: int
+        Number of objects to be processed.
+
+    lower_bound: int
+        Minimum partition size.
+
+    """
+    from .sge import SGEMultipleQueuesCluster
+
+    if not isinstance(cluster, SGEMultipleQueuesCluster):
+        return None
+
+    max_jobs = cluster.sge_job_spec["default"]["max_jobs"]
+
+    # Enforce a lower bound on the partition size so partitions do not get too small
+    return (
+        max(n_objects // max_jobs, lower_bound) if n_objects > max_jobs else n_objects
+    )
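
Usage sketch (not part of the patch): the snippet below shows how the suggested partition size could feed into dask.bag. Only dask_get_partition_size comes from the change above; the argument-less SGEMultipleQueuesCluster() construction and the sample list are assumptions for illustration.

    import dask.bag

    from bob.pipelines.distributed import dask_get_partition_size
    from bob.pipelines.distributed.sge import SGEMultipleQueuesCluster

    # Assumption: an SGE grid is reachable and the "default" queue spec applies.
    cluster = SGEMultipleQueuesCluster()

    # Stand-in for the objects to be processed (hypothetical data).
    samples = list(range(100_000))

    # None is returned for non-SGE clusters; fall back to dask's defaults in that case.
    partition_size = dask_get_partition_size(cluster, len(samples))
    if partition_size is not None:
        bag = dask.bag.from_sequence(samples, partition_size=partition_size)
    else:
        bag = dask.bag.from_sequence(samples)

For a sense of scale: with a default queue capped at, say, 48 jobs, 100 000 objects give a partition size of 100000 // 48 = 2083; when there are fewer objects than max_jobs the function simply returns n_objects.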