From 280495cb76ccb5a91b80e6003c939509368b4650 Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Thu, 16 Jul 2020 14:59:17 +0200
Subject: [PATCH] [script.significance] Implement checkpointing for patch
 performance calculations

---
 bob/ip/binseg/script/significance.py | 48 +++++++++++++++++++++++++++-
 1 file changed, 47 insertions(+), 1 deletion(-)

diff --git a/bob/ip/binseg/script/significance.py b/bob/ip/binseg/script/significance.py
index 1dd98283..777703fc 100755
--- a/bob/ip/binseg/script/significance.py
+++ b/bob/ip/binseg/script/significance.py
@@ -39,6 +39,7 @@ def _eval_patches(
     outdir,
     figure,
     nproc,
+    checkpointdir,
 ):
     """Calculates the patch performances on a dataset
 
@@ -99,6 +100,10 @@ def _eval_patches(
         ``1`` avoids completely the use of multiprocessing and runs all chores
         in the current processing context.
 
+    checkpointdir : str
+        If set to a string (instead of ``None``), then stores a cached version
+        of the patch performances on disk, for a particular system.
+
 
     Returns
     =======
@@ -128,6 +133,27 @@ def _eval_patches(
 
     """
 
+    if checkpointdir is not None:
+        chkpt_fname = os.path.join(checkpointdir,
+                f"{system_name}-{evaluate}-{threshold}-" \
+                f"{size[0]}x{size[1]}+{stride[0]}x{stride[1]}-{figure}.pkl.gz"
+                )
+        os.makedirs(os.path.dirname(chkpt_fname), exist_ok=True)
+        if os.path.exists(chkpt_fname):
+            logger.info(f"Loading checkpoint from {chkpt_fname}...")
+            # loads and returns checkpoint from file
+            try:
+                with __import__('gzip').GzipFile(chkpt_fname, "r") as f:
+                    return __import__('pickle').load(f)
+            except EOFError as e:
+                logger.warning(f"Could not load patch performance from " \
+                        f"{chkpt_fname}: {e}. Calculating...")
+        else:
+            logger.debug(f"Checkpoint not available at {chkpt_fname}. " \
+                    f"Calculating...")
+    else:
+        chkpt_fname = None
+
     if not isinstance(threshold, float):
 
         assert threshold in dataset, f"No dataset named '{threshold}'"
@@ -147,7 +173,7 @@ def _eval_patches(
         f"'{system_name}' using windows of size {size} and stride {stride}"
     )
 
-    return patch_performances(
+    retval = patch_performances(
         dataset,
         evaluate,
         preddir,
@@ -159,6 +185,14 @@ def _eval_patches(
         outdir,
     )
 
+    # cache patch performance for later use, if necessary
+    if chkpt_fname is not None:
+        logger.debug(f"Storing checkpoint at {chkpt_fname}...")
+        with __import__('gzip').GzipFile(chkpt_fname, "w") as f:
+            __import__('pickle').dump(retval, f)
+
+    return retval
+
 
 def _eval_differences(perf1, perf2, evaluate, dataset, size, stride, outdir,
         figure, nproc):
@@ -405,6 +439,15 @@ def _eval_differences(perf1, perf2, evaluate, dataset, size, stride, outdir,
     required=True,
     cls=ResourceOption,
 )
+@click.option(
+    "--checkpoint-folder",
+    "-k",
+    help="Path where to store checkpointed versions of patch performances",
+    required=False,
+    type=click.Path(),
+    show_default=True,
+    cls=ResourceOption,
+)
 @verbosity_option(cls=ResourceOption)
 def significance(
     names,
@@ -420,6 +463,7 @@ def significance(
     remove_outliers,
     remove_zeros,
     parallel,
+    checkpoint_folder,
     **kwargs,
 ):
     """Evaluates how significantly different are two models on the same dataset
@@ -446,6 +490,7 @@ def significance(
         else os.path.join(output_folder, names[0])),
         figure,
         parallel,
+        checkpoint_folder,
     )
 
     perf2 = _eval_patches(
@@ -462,6 +507,7 @@ def significance(
             else os.path.join(output_folder, names[1])),
         figure,
         parallel,
+        checkpoint_folder,
     )
 
     perf_diff = _eval_differences(
-- 
GitLab