Skip to content
Snippets Groups Projects
Commit 280495cb authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[script.significance] Implement checkpointing for patch performance calculations

parent f2e973ec
No related branches found
No related tags found
No related merge requests found
Pipeline #41254 passed
...@@ -39,6 +39,7 @@ def _eval_patches( ...@@ -39,6 +39,7 @@ def _eval_patches(
outdir, outdir,
figure, figure,
nproc, nproc,
checkpointdir,
): ):
"""Calculates the patch performances on a dataset """Calculates the patch performances on a dataset
...@@ -99,6 +100,10 @@ def _eval_patches( ...@@ -99,6 +100,10 @@ def _eval_patches(
``1`` avoids completely the use of multiprocessing and runs all chores ``1`` avoids completely the use of multiprocessing and runs all chores
in the current processing context. in the current processing context.
checkpointdir : str
If set to a string (instead of ``None``), then stores a cached version
of the patch performances on disk, for a particular system.
Returns Returns
======= =======
...@@ -128,6 +133,27 @@ def _eval_patches( ...@@ -128,6 +133,27 @@ def _eval_patches(
""" """
if checkpointdir is not None:
chkpt_fname = os.path.join(checkpointdir,
f"{system_name}-{evaluate}-{threshold}-" \
f"{size[0]}x{size[1]}+{stride[0]}x{stride[1]}-{figure}.pkl.gz"
)
os.makedirs(os.path.dirname(chkpt_fname), exist_ok=True)
if os.path.exists(chkpt_fname):
logger.info(f"Loading checkpoint from {chkpt_fname}...")
# loads and returns checkpoint from file
try:
with __import__('gzip').GzipFile(chkpt_fname, "r") as f:
return __import__('pickle').load(f)
except EOFError as e:
logger.warning(f"Could not load patch performance from " \
f"{chkpt_fname}: {e}. Calculating...")
else:
logger.debug(f"Checkpoint not available at {chkpt_fname}. " \
f"Calculating...")
else:
chkpt_fname = None
if not isinstance(threshold, float): if not isinstance(threshold, float):
assert threshold in dataset, f"No dataset named '{threshold}'" assert threshold in dataset, f"No dataset named '{threshold}'"
...@@ -147,7 +173,7 @@ def _eval_patches( ...@@ -147,7 +173,7 @@ def _eval_patches(
f"'{system_name}' using windows of size {size} and stride {stride}" f"'{system_name}' using windows of size {size} and stride {stride}"
) )
return patch_performances( retval = patch_performances(
dataset, dataset,
evaluate, evaluate,
preddir, preddir,
...@@ -159,6 +185,14 @@ def _eval_patches( ...@@ -159,6 +185,14 @@ def _eval_patches(
outdir, outdir,
) )
# cache patch performance for later use, if necessary
if chkpt_fname is not None:
logger.debug(f"Storing checkpoint at {chkpt_fname}...")
with __import__('gzip').GzipFile(chkpt_fname, "w") as f:
__import__('pickle').dump(retval, f)
return retval
def _eval_differences(perf1, perf2, evaluate, dataset, size, stride, outdir, def _eval_differences(perf1, perf2, evaluate, dataset, size, stride, outdir,
figure, nproc): figure, nproc):
...@@ -405,6 +439,15 @@ def _eval_differences(perf1, perf2, evaluate, dataset, size, stride, outdir, ...@@ -405,6 +439,15 @@ def _eval_differences(perf1, perf2, evaluate, dataset, size, stride, outdir,
required=True, required=True,
cls=ResourceOption, cls=ResourceOption,
) )
@click.option(
"--checkpoint-folder",
"-k",
help="Path where to store checkpointed versions of patch performances",
required=False,
type=click.Path(),
show_default=True,
cls=ResourceOption,
)
@verbosity_option(cls=ResourceOption) @verbosity_option(cls=ResourceOption)
def significance( def significance(
names, names,
...@@ -420,6 +463,7 @@ def significance( ...@@ -420,6 +463,7 @@ def significance(
remove_outliers, remove_outliers,
remove_zeros, remove_zeros,
parallel, parallel,
checkpoint_folder,
**kwargs, **kwargs,
): ):
"""Evaluates how significantly different are two models on the same dataset """Evaluates how significantly different are two models on the same dataset
...@@ -446,6 +490,7 @@ def significance( ...@@ -446,6 +490,7 @@ def significance(
else os.path.join(output_folder, names[0])), else os.path.join(output_folder, names[0])),
figure, figure,
parallel, parallel,
checkpoint_folder,
) )
perf2 = _eval_patches( perf2 = _eval_patches(
...@@ -462,6 +507,7 @@ def significance( ...@@ -462,6 +507,7 @@ def significance(
else os.path.join(output_folder, names[1])), else os.path.join(output_folder, names[1])),
figure, figure,
parallel, parallel,
checkpoint_folder,
) )
perf_diff = _eval_differences( perf_diff = _eval_differences(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment