diff --git a/bob/bio/base/pipelines/entry_points.py b/bob/bio/base/pipelines/entry_points.py
index ea2f6d84c204ff6da0ad821942f7d579befc374a..b7d2efb4debcbeab8f1a1bb0fecf20301072c3fd 100644
--- a/bob/bio/base/pipelines/entry_points.py
+++ b/bob/bio/base/pipelines/entry_points.py
@@ -18,8 +18,8 @@ from bob.bio.base.pipelines import (
     is_checkpointed,
 )
 from bob.pipelines.distributed import dask_get_partition_size
-from bob.pipelines.utils import is_estimator_stateless, isinstance_nested
 from bob.pipelines.distributed.sge import SGEMultipleQueuesCluster
+from bob.pipelines.utils import is_estimator_stateless, isinstance_nested
 
 logger = logging.getLogger(__name__)
 
@@ -185,11 +185,15 @@ def execute_pipeline_simple(
         if dask_partition_size is not None:
             # Create partitions of the same defined size for each Set
             n_objects = max(
-                len(background_model_samples), len(biometric_references), len(probes)
+                len(background_model_samples),
+                len(biometric_references),
+                len(probes),
             )
             partition_size = None
             if not isinstance(dask_client, str):
-                partition_size = dask_get_partition_size(dask_client.cluster, n_objects, dask_partition_size)
+                partition_size = dask_get_partition_size(
+                    dask_client.cluster, n_objects, dask_partition_size
+                )
             logger.debug("Splitting data with fixed size partitions.")
             pipeline = dask_pipeline_simple(
                 pipeline,
@@ -206,11 +210,15 @@ def execute_pipeline_simple(
             # Split in max_jobs partitions or revert to the default behavior of
             # dask.Bag from_sequence: partition_size = 100
             n_jobs = None
-            if not isinstance(dask_client, str) and isinstance(dask_client.cluster, SGEMultipleQueuesCluster):
+            if not isinstance(dask_client, str) and isinstance(
+                dask_client.cluster, SGEMultipleQueuesCluster
+            ):
                 logger.debug(
                     "Splitting data according to the number of available workers."
                 )
-                n_jobs = dask_client.cluster.sge_job_spec["default"]["max_jobs"]
+                n_jobs = dask_client.cluster.sge_job_spec["default"][
+                    "max_jobs"
+                ]
                 logger.debug(f"{n_jobs} partitions will be created.")
             pipeline = dask_pipeline_simple(pipeline, npartitions=n_jobs)
 
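
Note (not part of the diff): the two hunks in execute_pipeline_simple only re-wrap existing statements to satisfy the line-length limit; the partitioning decision itself is unchanged. A rough, standalone sketch of that decision follows; the helper name choose_partitioning and the duck-typed access to the cluster's job spec are illustrative stand-ins, not part of bob.bio.base.

    def choose_partitioning(dask_client, dask_partition_size):
        """Sketch of the branch reformatted above; illustrative only."""
        if dask_partition_size is not None:
            # Fixed-size partitions for every Set; in the real code this size
            # may still be adjusted to the cluster by dask_get_partition_size
            # when a real client object (not a string specification) is given.
            return {"partition_size": dask_partition_size}
        # Otherwise split according to the SGE "default" queue's max_jobs, or
        # fall back to dask.Bag.from_sequence defaults (partition_size = 100).
        n_jobs = None
        if not isinstance(dask_client, str) and hasattr(dask_client, "cluster"):
            sge_job_spec = getattr(dask_client.cluster, "sge_job_spec", None)
            if sge_job_spec is not None:
                n_jobs = sge_job_spec["default"]["max_jobs"]
        return {"npartitions": n_jobs}

    # e.g. with a string client specification, no cluster-aware choice is made
    # ("some-client-config" is a placeholder, not a real configuration name):
    assert choose_partitioning("some-client-config", None) == {"npartitions": None}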