From 8773ff98caab11c649ccbe661c7f4fd6a8b0d2bc Mon Sep 17 00:00:00 2001
From: Tiago Freitas Pereira <tiagofrepereira@gmail.com>
Date: Tue, 23 Jun 2020 17:23:45 +0200
Subject: [PATCH] Improved the checkpointing

---
 bob/bio/base/script/vanilla_biometrics.py     | 14 +++++++----
 .../base/script/vanilla_biometrics_ztnorm.py  | 23 +++++++++++++------
 2 files changed, 26 insertions(+), 11 deletions(-)

diff --git a/bob/bio/base/script/vanilla_biometrics.py b/bob/bio/base/script/vanilla_biometrics.py
index 88054034..4e7e5b01 100644
--- a/bob/bio/base/script/vanilla_biometrics.py
+++ b/bob/bio/base/script/vanilla_biometrics.py
@@ -26,6 +26,7 @@ from bob.bio.base.pipelines.vanilla_biometrics import (
     dask_get_partition_size,
     FourColumnsScoreWriter,
     CSVScoreWriter,
+    BioAlgorithmLegacy    
 )
 from dask.delayed import Delayed
 import pkg_resources
@@ -213,10 +214,15 @@ def vanilla_biometrics(
         pipeline.score_writer = FourColumnsScoreWriter(os.path.join(output, "./tmp"))
 
     # Check if it's already checkpointed
-    if checkpoint and not isinstance_nested(
-        pipeline.biometric_algorithm,
-        "biometric_algorithm",
-        BioAlgorithmCheckpointWrapper,
+    if checkpoint and (
+        not isinstance_nested(
+            pipeline,
+            "biometric_algorithm",
+            BioAlgorithmCheckpointWrapper,
+        )
+        or not isinstance_nested(
+            pipeline, "biometric_algorithm", BioAlgorithmLegacy
+        )
     ):
         pipeline = checkpoint_vanilla_biometrics(pipeline, output)
 
diff --git a/bob/bio/base/script/vanilla_biometrics_ztnorm.py b/bob/bio/base/script/vanilla_biometrics_ztnorm.py
index ce00590a..bd849a51 100644
--- a/bob/bio/base/script/vanilla_biometrics_ztnorm.py
+++ b/bob/bio/base/script/vanilla_biometrics_ztnorm.py
@@ -29,12 +29,17 @@ from bob.bio.base.pipelines.vanilla_biometrics import (
     dask_get_partition_size,
     FourColumnsScoreWriter,
     CSVScoreWriter,
+    BioAlgorithmLegacy,
 )
 from dask.delayed import Delayed
 from bob.bio.base.utils import get_resource_filename
 from bob.extension.config import load as chain_load
 from bob.pipelines.utils import isinstance_nested
-from .vanilla_biometrics import compute_scores, post_process_scores, load_database_pipeline
+from .vanilla_biometrics import (
+    compute_scores,
+    post_process_scores,
+    load_database_pipeline,
+)
 import copy
 
 logger = logging.getLogger(__name__)
@@ -67,11 +72,14 @@ EPILOG = """\b
     entry_point_group="bob.pipelines.config", cls=ConfigCommand, epilog=EPILOG,
 )
 @click.option(
-    "--pipeline", "-p", required=True, help="Vanilla biometrics pipeline composed of a scikit-learn Pipeline and a BioAlgorithm",
+    "--pipeline",
+    "-p",
+    required=True,
+    help="Vanilla biometrics pipeline composed of a scikit-learn Pipeline and a BioAlgorithm",
 )
 @click.option(
     "--database",
-    "-d",    
+    "-d",
     help="Biometric Database connector (class that implements the methods: `background_model_samples`, `references` and `probes`)",
 )
 @click.option(
@@ -211,10 +219,11 @@ def vanilla_biometrics_ztnorm(
         pipeline.score_writer = FourColumnsScoreWriter(os.path.join(output, "./tmp"))
 
     # Check if it's already checkpointed
-    if checkpoint and not isinstance_nested(
-        pipeline.biometric_algorithm,
-        "biometric_algorithm",
-        BioAlgorithmCheckpointWrapper,
+    if checkpoint and (
+        not isinstance_nested(
+            pipeline, "biometric_algorithm", BioAlgorithmCheckpointWrapper,
+        )
+        and not isinstance_nested(pipeline, "biometric_algorithm", BioAlgorithmLegacy)
     ):
         pipeline = checkpoint_vanilla_biometrics(pipeline, output)
 
-- 
GitLab