diff --git a/src/mednet/libs/classification/scripts/experiment.py b/src/mednet/libs/classification/scripts/experiment.py index 5fdf8c6c94d5a58253b5e11b301990f2d0d37384..f15f4c7d89ae7d6a54da74fc8594320d06e491c9 100644 --- a/src/mednet/libs/classification/scripts/experiment.py +++ b/src/mednet/libs/classification/scripts/experiment.py @@ -2,6 +2,7 @@ # # SPDX-License-Identifier: GPL-3.0-or-later +import json from datetime import datetime import click @@ -133,11 +134,23 @@ def experiment( from .evaluate import evaluate + predictions_file = predictions_output / "predictions.json" + + with (predictions_output / "predictions.json").open() as pf: + splits = json.load(pf).keys() + + if "validation" in splits: + evaluation_threshold = "validation" + elif "train" in splits: + evaluation_threshold = "train" + else: + evaluation_threshold = None + ctx.invoke( evaluate, - predictions=predictions_output / "predictions.json", + predictions=predictions_file, output_folder=output_folder, - threshold="validation", + threshold=evaluation_threshold, ) evaluation_stop_timestamp = datetime.now() diff --git a/src/mednet/libs/segmentation/scripts/experiment.py b/src/mednet/libs/segmentation/scripts/experiment.py index 0f0db12fbcb183192446b1943f29c9b6da63b44d..ea74b2dd25d7a158b0faa3668e6229840bd47ff1 100644 --- a/src/mednet/libs/segmentation/scripts/experiment.py +++ b/src/mednet/libs/segmentation/scripts/experiment.py @@ -2,6 +2,7 @@ # # SPDX-License-Identifier: GPL-3.0-or-later +import json from datetime import datetime import click @@ -138,11 +139,23 @@ def experiment( evaluation_output = output_folder / "evaluation" + predictions_file = predictions_output / "predictions.json" + + with (predictions_output / "predictions.json").open() as pf: + splits = json.load(pf).keys() + + if "validation" in splits: + evaluation_threshold = "validation" + elif "train" in splits: + evaluation_threshold = "train" + else: + evaluation_threshold = None + ctx.invoke( evaluate, - predictions=predictions_output / "predictions.json", + predictions=predictions_file, output_folder=evaluation_output, - threshold=0.5, + threshold=evaluation_threshold, ) evaluation_stop_timestamp = datetime.now()