Commit a2bc11ba authored by Yannick DAYER

[refactor] isort

parent 95ee08d1
1 merge request: !290 PipelineSimple partitioning fixes
Pipeline #60905 passed
@@ -18,8 +18,8 @@ from bob.bio.base.pipelines import (
     is_checkpointed,
 )
 from bob.pipelines.distributed import dask_get_partition_size
-from bob.pipelines.utils import is_estimator_stateless, isinstance_nested
 from bob.pipelines.distributed.sge import SGEMultipleQueuesCluster
+from bob.pipelines.utils import is_estimator_stateless, isinstance_nested
 
 logger = logging.getLogger(__name__)
 
@@ -185,11 +185,15 @@ def execute_pipeline_simple(
     if dask_partition_size is not None:
         # Create partitions of the same defined size for each Set
         n_objects = max(
-            len(background_model_samples), len(biometric_references), len(probes)
+            len(background_model_samples),
+            len(biometric_references),
+            len(probes),
         )
         partition_size = None
         if not isinstance(dask_client, str):
-            partition_size = dask_get_partition_size(dask_client.cluster, n_objects, dask_partition_size)
+            partition_size = dask_get_partition_size(
+                dask_client.cluster, n_objects, dask_partition_size
+            )
         logger.debug("Splitting data with fixed size partitions.")
         pipeline = dask_pipeline_simple(
             pipeline,
@@ -206,11 +206,15 @@ def execute_pipeline_simple(
         # Split in max_jobs partitions or revert to the default behavior of
         # dask.Bag from_sequence: partition_size = 100
         n_jobs = None
-        if not isinstance(dask_client, str) and isinstance(dask_client.cluster, SGEMultipleQueuesCluster):
+        if not isinstance(dask_client, str) and isinstance(
+            dask_client.cluster, SGEMultipleQueuesCluster
+        ):
             logger.debug(
                 "Splitting data according to the number of available workers."
             )
-            n_jobs = dask_client.cluster.sge_job_spec["default"]["max_jobs"]
+            n_jobs = dask_client.cluster.sge_job_spec["default"][
+                "max_jobs"
+            ]
             logger.debug(f"{n_jobs} partitions will be created.")
         pipeline = dask_pipeline_simple(pipeline, npartitions=n_jobs)
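The two hunks in execute_pipeline_simple only re-wrap lines, but the logic they touch decides how many dask.Bag partitions the data is split into. Below is a minimal, self-contained sketch of that decision under simplified assumptions: FakeSGECluster, FakeClient and pick_npartitions are hypothetical stand-ins, and the fixed-size branch uses plain ceiling division where the real code derives the size via bob.pipelines' dask_get_partition_size, which also inspects the cluster.

# A minimal sketch, not the bob.bio.base implementation: FakeSGECluster,
# FakeClient and pick_npartitions are hypothetical stand-ins illustrating
# the branching that the two hunks above re-wrap.
from dataclasses import dataclass, field


@dataclass
class FakeSGECluster:
    # Models only what this code path reads from SGEMultipleQueuesCluster:
    # a job spec with a "max_jobs" budget for the default queue.
    sge_job_spec: dict = field(
        default_factory=lambda: {"default": {"max_jobs": 48}}
    )


@dataclass
class FakeClient:
    cluster: FakeSGECluster = field(default_factory=FakeSGECluster)


def pick_npartitions(dask_client, n_objects, dask_partition_size=None):
    """Return the number of dask.Bag partitions to create.

    Mirrors the decision in the diff: a fixed partition size wins when
    given; otherwise an SGE cluster is split into max_jobs partitions;
    otherwise dask.bag.from_sequence's default partition_size=100 applies.
    """
    if dask_partition_size is not None:
        # Fixed-size partitions (the real code calls dask_get_partition_size
        # here instead of this plain ceiling division).
        return -(-n_objects // dask_partition_size)
    if isinstance(getattr(dask_client, "cluster", None), FakeSGECluster):
        # One partition per worker the SGE scheduler may run concurrently.
        return dask_client.cluster.sge_job_spec["default"]["max_jobs"]
    return -(-n_objects // 100)  # dask.Bag from_sequence default


n_objects = max(1000, 250, 5000)  # background samples, references, probes
print(pick_npartitions(FakeClient(), n_objects, dask_partition_size=200))  # 25
print(pick_npartitions(FakeClient(), n_objects))                           # 48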