From cb83794da843358b822b8d75abfdecc1db7f770c Mon Sep 17 00:00:00 2001 From: Tiago Freitas Pereira <tiagofrepereira@gmail.com> Date: Tue, 15 Dec 2020 16:34:05 +0100 Subject: [PATCH] Fixed partition size issue with the ZT-Norm pipeline --- .../vanilla_biometrics/vanilla_biometrics.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/bob/bio/base/pipelines/vanilla_biometrics/vanilla_biometrics.py b/bob/bio/base/pipelines/vanilla_biometrics/vanilla_biometrics.py index 0b9084e7..32fcd62a 100644 --- a/bob/bio/base/pipelines/vanilla_biometrics/vanilla_biometrics.py +++ b/bob/bio/base/pipelines/vanilla_biometrics/vanilla_biometrics.py @@ -101,7 +101,9 @@ def execute_vanilla_biometrics( for group in groups: - score_file_name = os.path.join(output, f"scores-{group}") + score_file_name = os.path.join( + output, f"scores-{group}" + ".csv" if write_metadata_scores else "" + ) biometric_references = database.references(group=group) probes = database.probes(group=group) @@ -246,13 +248,20 @@ def execute_vanilla_biometrics_ztnorm( if dask_client is not None and not isinstance_nested( pipeline.biometric_algorithm, "biometric_algorithm", BioAlgorithmDaskWrapper ): + # Scaling up + if dask_n_workers is not None and not isinstance(dask_client, str): + dask_client.cluster.scale(dask_n_workers) + n_objects = max( len(background_model_samples), len(biometric_references), len(probes) ) - pipeline = dask_vanilla_biometrics( - pipeline, - partition_size=dask_get_partition_size(dask_client.cluster, n_objects), - ) + partition_size = None + if not isinstance(dask_client, str): + partition_size = dask_get_partition_size(dask_client.cluster, n_objects) + if dask_partition_size is not None: + partition_size = dask_partition_size + + pipeline = dask_vanilla_biometrics(pipeline, partition_size=partition_size,) logger.info(f"Running vanilla biometrics for group {group}") allow_scoring_with_all_biometric_references = ( -- GitLab