diff --git a/bob/ip/binseg/script/binseg.py b/bob/ip/binseg/script/binseg.py index 5fea88b1ffc244b4a6a55d022cbb666332de7e2b..091792acd02bd5629d94d1ccf64bed4bdcedaec6 100644 --- a/bob/ip/binseg/script/binseg.py +++ b/bob/ip/binseg/script/binseg.py @@ -4,6 +4,7 @@ """The main entry for bob ip binseg (click-based) scripts.""" import os +import re import sys import time import tempfile @@ -20,6 +21,30 @@ import logging logger = logging.getLogger(__name__) +def escape_name(v): + """Escapes a name so it contains filesystem friendly characters only + + This function escapes every character that's not a letter, ``_``, ``-``, + ``.`` or space with an ``-``. + + + Parameters + ========== + + v : str + String to be escaped + + + Returns + ======= + + s : str + Escaped string + + """ + return re.sub(r'[^\w\-_\. ]', '-', v) + + def save_sh_command(destfile): """Records command-line to reproduce this experiment diff --git a/bob/ip/binseg/script/evaluate.py b/bob/ip/binseg/script/evaluate.py index 4a8a0c2af913453b83882cd2c5d38538caa27e13..1cfc2bf1e33245891d39465b015dcf27cc9d8447 100644 --- a/bob/ip/binseg/script/evaluate.py +++ b/bob/ip/binseg/script/evaluate.py @@ -45,6 +45,38 @@ def _validate_threshold(t, dataset): return t +def _get_folder(folder, name): + """Guesses the prediction folder to use based on the dataset name + + This function will look for ``folder/name`` if it exists, and + return this. Otherwise defaults to ``folder``. + + + Parameters + ========== + + folder : str + Path to the root of the predictions folder + + name : str + The name of the dataset for which we are trying to find the predictions + folder + + + Returns + ======= + + path : str + The best path to use as the root of the predictions folder for this + dataset. + + """ + candidate = os.path.join(folder, name) + if os.path.exists(candidate): + return candidate + return folder + + @click.command( entry_point_group="bob.ip.binseg.config", cls=ConfigCommand, @@ -170,13 +202,17 @@ def evaluate( second_annotator = {} elif not isinstance(second_annotator, dict): second_annotator = {"test": second_annotator} - #else, second_annotator must be a dict + # else, second_annotator must be a dict if isinstance(threshold, str): # first run evaluation for reference dataset, do not save overlays logger.info(f"Evaluating threshold on '{threshold}' set") - threshold = run(dataset[threshold], threshold, predictions_folder, - steps=steps) + threshold = run( + dataset[threshold], + threshold, + _get_folder(predictions_folder, threshold), + steps=steps, + ) logger.info(f"Set --threshold={threshold:.5f}") # now run with the @@ -185,13 +221,22 @@ def evaluate( logger.info(f"Skipping dataset '{k}' (not to be evaluated)") continue logger.info(f"Analyzing '{k}' set...") - run(v, k, predictions_folder, output_folder, overlayed, threshold, - steps=steps) + run( + v, + k, + _get_folder(predictions_folder, k), + output_folder, + overlayed, + threshold, + steps=steps, + ) second = second_annotator.get(k) if second is not None: if not second.all_keys_match(v): - logger.warning(f"Key mismatch between `dataset[{k}]` and " \ - f"`second_annotator[{k}]` - skipping " \ - f"second-annotator comparisons for {k} subset") + logger.warning( + f"Key mismatch between `dataset[{k}]` and " + f"`second_annotator[{k}]` - skipping " + f"second-annotator comparisons for {k} subset" + ) else: compare_annotators(v, second, k, output_folder, overlayed) diff --git a/bob/ip/binseg/script/predict.py b/bob/ip/binseg/script/predict.py index e096b60dc3c236d442e6af66a0f48471af6183a6..06a2edbf6156ea7a402ee7be858ac01b9cd6d207 100644 --- a/bob/ip/binseg/script/predict.py +++ b/bob/ip/binseg/script/predict.py @@ -145,4 +145,8 @@ def predict(output_folder, model, dataset, batch_size, device, weight, shuffle=False, pin_memory=torch.cuda.is_available(), ) - run(model, data_loader, device, output_folder, overlayed) + # this avoids collisions if we have, e.g., multi-resolution versions + # of the same dataset being evaluated, or datasets for which filenames + # may match. + use_output_folder = os.path.join(output_folder, k) + run(model, data_loader, device, use_output_folder, overlayed)