From 3797c15f70451ab1d5876ca84d92096fd954f521 Mon Sep 17 00:00:00 2001
From: Tiago Freitas Pereira <tiagofrepereira@gmail.com>
Date: Thu, 2 Jul 2020 08:59:41 +0200
Subject: [PATCH] Implemented a mechanism to get the partition_size

---
 bob/bio/base/script/vanilla_biometrics.py        |  6 +++---
 bob/bio/base/script/vanilla_biometrics_ztnorm.py | 11 ++++++-----
 2 files changed, 9 insertions(+), 8 deletions(-)

diff --git a/bob/bio/base/script/vanilla_biometrics.py b/bob/bio/base/script/vanilla_biometrics.py
index 4e7e5b01..e2d5341e 100644
--- a/bob/bio/base/script/vanilla_biometrics.py
+++ b/bob/bio/base/script/vanilla_biometrics.py
@@ -220,7 +220,7 @@ def vanilla_biometrics(
             "biometric_algorithm",
             BioAlgorithmCheckpointWrapper,
         )
-        or not isinstance_nested(
+        and not isinstance_nested(
             pipeline, "biometric_algorithm", BioAlgorithmLegacy
         )
     ):
@@ -236,8 +236,8 @@ def vanilla_biometrics(
         if dask_client is not None and not isinstance_nested(
             pipeline.biometric_algorithm, "biometric_algorithm", BioAlgorithmDaskWrapper
         ):
-            n_objects = (
-                len(background_model_samples) + len(biometric_references) + len(probes)
+            n_objects = max(
+                len(background_model_samples), len(biometric_references), len(probes)
             )
             pipeline = dask_vanilla_biometrics(
                 pipeline,
diff --git a/bob/bio/base/script/vanilla_biometrics_ztnorm.py b/bob/bio/base/script/vanilla_biometrics_ztnorm.py
index bd849a51..4150e232 100644
--- a/bob/bio/base/script/vanilla_biometrics_ztnorm.py
+++ b/bob/bio/base/script/vanilla_biometrics_ztnorm.py
@@ -229,9 +229,10 @@ def vanilla_biometrics_ztnorm(
 
     # Patching the pipeline in case of ZNorm and checkpointing it
     pipeline = ZTNormPipeline(pipeline)
-    pipeline.ztnorm_solver = ZTNormCheckpointWrapper(
-        pipeline.ztnorm_solver, os.path.join(output, "normed-scores")
-    )
+    if checkpoint:
+        pipeline.ztnorm_solver = ZTNormCheckpointWrapper(
+            pipeline.ztnorm_solver, os.path.join(output, "normed-scores")
+        )
 
     background_model_samples = database.background_model_samples()
     zprobes = database.zprobes(proportion=ztnorm_cohort_proportion)
@@ -246,8 +247,8 @@ def vanilla_biometrics_ztnorm(
         if dask_client is not None and not isinstance_nested(
             pipeline.biometric_algorithm, "biometric_algorithm", BioAlgorithmDaskWrapper
         ):
-            n_objects = (
-                len(background_model_samples) + len(biometric_references) + len(probes)
+            n_objects = max(
+                len(background_model_samples), len(biometric_references), len(probes)
             )
             pipeline = dask_vanilla_biometrics(
                 pipeline,
-- 
GitLab