[dask] Make vanilla-pad work properly with dask

parent 9284e2a7
Pipeline #46386 passed in 4 minutes and 13 seconds
@@ -6,6 +6,7 @@ import click
from bob.extension.scripts.click_helper import ConfigCommand
from bob.extension.scripts.click_helper import ResourceOption
from bob.extension.scripts.click_helper import verbosity_option
from bob.pipelines.distributed import dask_get_partition_size
@click.command(
@@ -67,10 +68,39 @@ from bob.extension.scripts.click_helper import verbosity_option
help="If set, it will checkpoint all steps of the pipeline. Checkpoints will be saved in `--output`.",
cls=ResourceOption,
)
@click.option(
    "--dask-partition-size",
    "-s",
    help="If using Dask, this option defines the size of each dask.bag partition. "
    "Use this option if the current heuristic that sets this value doesn't suit your experiment "
    "(https://docs.dask.org/en/latest/bag-api.html?highlight=partition_size#dask.bag.from_sequence).",
    default=None,
    type=click.INT,
    cls=ResourceOption,
)
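# Note: this value is forwarded (via the dask wrapper) to
# dask.bag.from_sequence(..., partition_size=N); e.g. 1000 samples with a
# partition size of 50 become 20 partitions, each processed as one dask task.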
@click.option(
    "--dask-n-workers",
    "-n",
    help="If using Dask, this option defines the number of workers to start your experiment with. "
    "Dask automatically scales the number of workers up or down depending on the current load of tasks to be solved. "
    "Use this option if the default initial number of workers doesn't suit your experiment.",
    default=None,
    type=click.INT,
    cls=ResourceOption,
)
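# Note: on adaptive deployments this is only the initial scale request; as the
# help text above says, dask may still resize the worker pool afterwards.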
@verbosity_option(cls=ResourceOption)
@click.pass_context
def vanilla_pad(
    ctx, pipeline, database, dask_client, groups, output, checkpoint, **kwargs
    ctx,
    pipeline,
    database,
    dask_client,
    groups,
    output,
    checkpoint,
    dask_partition_size,
    dask_n_workers,
    **kwargs,
):
"""Runs the simplest PAD pipeline."""
@@ -84,6 +114,8 @@ def vanilla_pad(
    import dask.bag

    from bob.extension.scripts.click_helper import log_parameters
    from bob.pipelines.distributed.sge import get_resource_requirements
    from bob.pipelines.utils import isinstance_nested
    from bob.pipelines.wrappers import DaskWrapper

    logger = logging.getLogger(__name__)
    log_parameters(logger)
@@ -95,8 +127,33 @@ def vanilla_pad(
["checkpoint"], pipeline, features_dir=output, model_path=output
)
if dask_client is None:
logger.warning("`dask_client` not set. Your pipeline will run locally")
    # Fetching samples
    fit_samples = database.fit_samples()
    total_samples = len(fit_samples)
    predict_samples = dict()
    for group in groups:
        predict_samples[group] = database.predict_samples(group=group)
        total_samples += len(predict_samples[group])
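    # The total number of samples (fit + predict) is needed below to derive a
    # partition size that spreads the whole workload over the cluster's workers.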
    # Checking if the pipeline is already dask-wrapped
    first_step = pipeline[0]
    if not isinstance_nested(first_step, "estimator", DaskWrapper):
        # Scaling up if necessary
        if dask_n_workers is not None and not isinstance(dask_client, str):
            dask_client.cluster.scale(dask_n_workers)
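        # (when `dask_client` is given as a plain string there is no live
        # cluster object to scale here, nor to inspect for the heuristic below)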
        # Defining the partition size
        partition_size = None
        if not isinstance(dask_client, str):
            lower_bound = 25  # lower bound of 25 videos per chunk
            partition_size = dask_get_partition_size(
                dask_client.cluster, total_samples, lower_bound=lower_bound
            )
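        # dask_get_partition_size is expected to spread `total_samples` evenly
        # over the cluster's workers while keeping at least `lower_bound`
        # samples per partition (assumed behaviour of the helper).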
        # An explicit --dask-partition-size overrides the heuristic
        if dask_partition_size is not None:
            partition_size = dask_partition_size

        pipeline = mario.wrap(["dask"], pipeline, partition_size=partition_size)
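        # Wrapping with ["dask"] makes samples flow through the pipeline as
        # dask bags; the resulting graphs are computed through `dask_client`
        # later on (or locally when no client is configured).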
    # create an experiment info file
    with open(os.path.join(output, "Experiment_info.txt"), "wt") as f:
@@ -107,14 +164,12 @@ def vanilla_pad(
f.write(f"Step {i}: {name}\n{estimator!r}\n")
    # train the pipeline
    fit_samples = database.fit_samples()
    pipeline.fit(fit_samples)

    for group in groups:
        logger.info(f"Running vanilla pad for group {group}")
        predict_samples = database.predict_samples(group=group)
        result = pipeline.decision_function(predict_samples)
        result = pipeline.decision_function(predict_samples[group])
        scores_path = os.path.join(output, f"scores-{group}")
......