From 45b7f66c73fbc3db6089d87a6a243c18479e58b4 Mon Sep 17 00:00:00 2001
From: Tiago Freitas Pereira <tiagofrepereira@gmail.com>
Date: Mon, 27 Jul 2020 13:29:05 +0200
Subject: [PATCH] Created the function is_checkpointed

---
 .../pipelines/vanilla_biometrics/__init__.py    |  1 +
 bob/bio/base/script/vanilla_biometrics.py       | 17 ++++++-----------
 .../base/script/vanilla_biometrics_ztnorm.py    |  8 ++------
 3 files changed, 9 insertions(+), 17 deletions(-)

diff --git a/bob/bio/base/pipelines/vanilla_biometrics/__init__.py b/bob/bio/base/pipelines/vanilla_biometrics/__init__.py
index 58a34e16..6388efbd 100644
--- a/bob/bio/base/pipelines/vanilla_biometrics/__init__.py
+++ b/bob/bio/base/pipelines/vanilla_biometrics/__init__.py
@@ -8,6 +8,7 @@ from .wrappers import (
     dask_vanilla_biometrics,
     checkpoint_vanilla_biometrics,
     dask_get_partition_size,
+    is_checkpointed
 )
 
 from .zt_norm import ZTNormPipeline, ZTNormDaskWrapper, ZTNormCheckpointWrapper
diff --git a/bob/bio/base/script/vanilla_biometrics.py b/bob/bio/base/script/vanilla_biometrics.py
index e2d5341e..b71009cc 100644
--- a/bob/bio/base/script/vanilla_biometrics.py
+++ b/bob/bio/base/script/vanilla_biometrics.py
@@ -26,7 +26,8 @@ from bob.bio.base.pipelines.vanilla_biometrics import (
     dask_get_partition_size,
     FourColumnsScoreWriter,
     CSVScoreWriter,
-    BioAlgorithmLegacy    
+    BioAlgorithmLegacy,
+    is_checkpointed
 )
 from dask.delayed import Delayed
 import pkg_resources
@@ -214,19 +215,11 @@ def vanilla_biometrics(
         pipeline.score_writer = FourColumnsScoreWriter(os.path.join(output, "./tmp"))
 
     # Check if it's already checkpointed
-    if checkpoint and (
-        not isinstance_nested(
-            pipeline,
-            "biometric_algorithm",
-            BioAlgorithmCheckpointWrapper,
-        )
-        and not isinstance_nested(
-            pipeline, "biometric_algorithm", BioAlgorithmLegacy
-        )
-    ):
+    if checkpoint and not is_checkpointed(pipeline):
         pipeline = checkpoint_vanilla_biometrics(pipeline, output)
 
     background_model_samples = database.background_model_samples()
+
     for group in groups:
 
         score_file_name = os.path.join(output, f"scores-{group}")
@@ -236,6 +229,7 @@ def vanilla_biometrics(
         if dask_client is not None and not isinstance_nested(
             pipeline.biometric_algorithm, "biometric_algorithm", BioAlgorithmDaskWrapper
         ):
+
             n_objects = max(
                 len(background_model_samples), len(biometric_references), len(probes)
             )
@@ -258,6 +252,7 @@ def vanilla_biometrics(
             allow_scoring_with_all_biometric_references=allow_scoring_with_all_biometric_references,
         )
 
+
         post_processed_scores = post_process_scores(pipeline, result, score_file_name)
         _ = compute_scores(post_processed_scores, dask_client)
 
diff --git a/bob/bio/base/script/vanilla_biometrics_ztnorm.py b/bob/bio/base/script/vanilla_biometrics_ztnorm.py
index 4150e232..0cc11b39 100644
--- a/bob/bio/base/script/vanilla_biometrics_ztnorm.py
+++ b/bob/bio/base/script/vanilla_biometrics_ztnorm.py
@@ -30,6 +30,7 @@ from bob.bio.base.pipelines.vanilla_biometrics import (
     FourColumnsScoreWriter,
     CSVScoreWriter,
     BioAlgorithmLegacy,
+    is_checkpointed    
 )
 from dask.delayed import Delayed
 from bob.bio.base.utils import get_resource_filename
@@ -219,12 +220,7 @@ def vanilla_biometrics_ztnorm(
         pipeline.score_writer = FourColumnsScoreWriter(os.path.join(output, "./tmp"))
 
     # Check if it's already checkpointed
-    if checkpoint and (
-        not isinstance_nested(
-            pipeline, "biometric_algorithm", BioAlgorithmCheckpointWrapper,
-        )
-        and not isinstance_nested(pipeline, "biometric_algorithm", BioAlgorithmLegacy)
-    ):
+    if checkpoint and not is_checkpointed(pipeline):
         pipeline = checkpoint_vanilla_biometrics(pipeline, output)
 
     # Patching the pipeline in case of ZNorm and checkpointing it
-- 
GitLab