Skip to content
Snippets Groups Projects
Commit e32c1ca8 authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[script.significance] Implement checkpointing for patch difference evaluation

parent c392a74f
No related branches found
No related tags found
No related merge requests found
Pipeline #41313 failed
......@@ -194,15 +194,19 @@ def _eval_patches(
return retval
def _eval_differences(perf1, perf2, evaluate, dataset, size, stride, outdir,
figure, nproc):
def _eval_differences(names, perfs, evaluate, dataset, size, stride, outdir,
figure, nproc, checkpointdir):
"""Evaluate differences in the performance patches between two systems
Parameters
----------
perf1, perf2 : dict
A dictionary as returned by :py:func:`_eval_patches`
names : :py:class:`tuple` of :py:class:`str`
Names of the first and second systems
perfs : :py:class:`tuple` of :py:class:`dict`
Dictionaries for the patch performances of each system, as returned by
:py:func:`_eval_patches`
evaluate : str
Name of the dataset key to use from ``dataset`` to evaluate (typically,
......@@ -235,6 +239,11 @@ def _eval_differences(perf1, perf2, evaluate, dataset, size, stride, outdir,
``1`` avoids completely the use of multiprocessing and runs all chores
in the current processing context.
checkpointdir : str
If set to a string (instead of ``None``), then stores a cached version
of the patch performances on disk, for a particular difference between
systems.
Returns
-------
......@@ -246,7 +255,28 @@ def _eval_differences(perf1, perf2, evaluate, dataset, size, stride, outdir,
"""
perf_diff = dict([(k, perf1[k]["df"].copy()) for k in perf1])
if checkpointdir is not None:
chkpt_fname = os.path.join(checkpointdir,
f"{names[0]}-{names[1]}-{evaluate}-" \
f"{size[0]}x{size[1]}+{stride[0]}x{stride[1]}-{figure}.pkl.gz"
)
os.makedirs(os.path.dirname(chkpt_fname), exist_ok=True)
if os.path.exists(chkpt_fname):
logger.info(f"Loading checkpoint from {chkpt_fname}...")
# loads and returns checkpoint from file
try:
with __import__('gzip').GzipFile(chkpt_fname, "r") as f:
return __import__('pickle').load(f)
except EOFError as e:
logger.warning(f"Could not load patch performance from " \
f"{chkpt_fname}: {e}. Calculating...")
else:
logger.debug(f"Checkpoint not available at {chkpt_fname}. " \
f"Calculating...")
else:
chkpt_fname = None
perf_diff = dict([(k, perfs[0][k]["df"].copy()) for k in perfs[0]])
# we can subtract these
to_subtract = (
......@@ -260,9 +290,9 @@ def _eval_differences(perf1, perf2, evaluate, dataset, size, stride, outdir,
for k in perf_diff:
for col in to_subtract:
perf_diff[k][col] -= perf2[k]["df"][col]
perf_diff[k][col] -= perfs[1][k]["df"][col]
return visual_performances(
retval = visual_performances(
dataset,
evaluate,
perf_diff,
......@@ -273,6 +303,14 @@ def _eval_differences(perf1, perf2, evaluate, dataset, size, stride, outdir,
outdir,
)
# cache patch performance for later use, if necessary
if chkpt_fname is not None:
logger.debug(f"Storing checkpoint at {chkpt_fname}...")
with __import__('gzip').GzipFile(chkpt_fname, "w") as f:
__import__('pickle').dump(retval, f)
return retval
@click.command(
entry_point_group="bob.ip.binseg.config",
......@@ -522,6 +560,7 @@ def significance(
else os.path.join(output_folder, "diff")),
figure,
parallel,
checkpoint_folder,
)
# loads all figures for the given threshold
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment