From cb83794da843358b822b8d75abfdecc1db7f770c Mon Sep 17 00:00:00 2001
From: Tiago Freitas Pereira <tiagofrepereira@gmail.com>
Date: Tue, 15 Dec 2020 16:34:05 +0100
Subject: [PATCH] Fixed partition size isse with the ZT-Norm pipeline

---
 .../vanilla_biometrics/vanilla_biometrics.py  | 19 ++++++++++++++-----
 1 file changed, 14 insertions(+), 5 deletions(-)

diff --git a/bob/bio/base/pipelines/vanilla_biometrics/vanilla_biometrics.py b/bob/bio/base/pipelines/vanilla_biometrics/vanilla_biometrics.py
index 0b9084e7..32fcd62a 100644
--- a/bob/bio/base/pipelines/vanilla_biometrics/vanilla_biometrics.py
+++ b/bob/bio/base/pipelines/vanilla_biometrics/vanilla_biometrics.py
@@ -101,7 +101,9 @@ def execute_vanilla_biometrics(
 
     for group in groups:
 
-        score_file_name = os.path.join(output, f"scores-{group}")
+        score_file_name = os.path.join(
+            output, f"scores-{group}" + ".csv" if write_metadata_scores else ""
+        )
         biometric_references = database.references(group=group)
         probes = database.probes(group=group)
 
@@ -246,13 +248,20 @@ def execute_vanilla_biometrics_ztnorm(
         if dask_client is not None and not isinstance_nested(
             pipeline.biometric_algorithm, "biometric_algorithm", BioAlgorithmDaskWrapper
         ):
+            # Scaling up
+            if dask_n_workers is not None and not isinstance(dask_client, str):
+                dask_client.cluster.scale(dask_n_workers)
+
             n_objects = max(
                 len(background_model_samples), len(biometric_references), len(probes)
             )
-            pipeline = dask_vanilla_biometrics(
-                pipeline,
-                partition_size=dask_get_partition_size(dask_client.cluster, n_objects),
-            )
+            partition_size = None
+            if not isinstance(dask_client, str):
+                partition_size = dask_get_partition_size(dask_client.cluster, n_objects)
+            if dask_partition_size is not None:
+                partition_size = dask_partition_size
+
+            pipeline = dask_vanilla_biometrics(pipeline, partition_size=partition_size,)
 
         logger.info(f"Running vanilla biometrics for group {group}")
         allow_scoring_with_all_biometric_references = (
-- 
GitLab