From 18b5d9e621c5bcbac5996fe1a0497975607161ec Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Tue, 28 May 2024 14:21:13 +0200
Subject: [PATCH] [experiment] Check validation split exists for evaluation
 threshold

If no split named "validation" exists, the experiment falls back to the
"train" split to compute the evaluation threshold. If neither split is
present, no threshold split is passed (None).
---
 .../libs/classification/scripts/experiment.py | 17 +++++++++++++++--
 .../libs/segmentation/scripts/experiment.py   | 17 +++++++++++++++--
 2 files changed, 30 insertions(+), 4 deletions(-)

diff --git a/src/mednet/libs/classification/scripts/experiment.py b/src/mednet/libs/classification/scripts/experiment.py
index 5fdf8c6c..f15f4c7d 100644
--- a/src/mednet/libs/classification/scripts/experiment.py
+++ b/src/mednet/libs/classification/scripts/experiment.py
@@ -2,6 +2,7 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
+import json
 from datetime import datetime
 
 import click
@@ -133,11 +134,23 @@ def experiment(
 
     from .evaluate import evaluate
 
+    predictions_file = predictions_output / "predictions.json"
+
+    with predictions_file.open() as pf:
+        splits = json.load(pf).keys()
+
+    if "validation" in splits:
+        evaluation_threshold = "validation"
+    elif "train" in splits:
+        evaluation_threshold = "train"
+    else:
+        evaluation_threshold = None
+
     ctx.invoke(
         evaluate,
-        predictions=predictions_output / "predictions.json",
+        predictions=predictions_file,
         output_folder=output_folder,
-        threshold="validation",
+        threshold=evaluation_threshold,
     )
 
     evaluation_stop_timestamp = datetime.now()
diff --git a/src/mednet/libs/segmentation/scripts/experiment.py b/src/mednet/libs/segmentation/scripts/experiment.py
index 0f0db12f..ea74b2dd 100644
--- a/src/mednet/libs/segmentation/scripts/experiment.py
+++ b/src/mednet/libs/segmentation/scripts/experiment.py
@@ -2,6 +2,7 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
+import json
 from datetime import datetime
 
 import click
@@ -138,11 +139,23 @@ def experiment(
 
     evaluation_output = output_folder / "evaluation"
 
+    predictions_file = predictions_output / "predictions.json"
+
+    with predictions_file.open() as pf:
+        splits = json.load(pf).keys()
+
+    if "validation" in splits:
+        evaluation_threshold = "validation"
+    elif "train" in splits:
+        evaluation_threshold = "train"
+    else:
+        evaluation_threshold = None
+
     ctx.invoke(
         evaluate,
-        predictions=predictions_output / "predictions.json",
+        predictions=predictions_file,
         output_folder=evaluation_output,
-        threshold=0.5,
+        threshold=evaluation_threshold,
     )
 
     evaluation_stop_timestamp = datetime.now()
--
GitLab
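
For reference, the split-selection fallback that the patch adds inline to both
experiment scripts boils down to the sketch below. This is a minimal,
standalone illustration only: the helper name pick_threshold_split is
hypothetical and does not exist in mednet, which keeps the logic inline as
shown in the hunks above.

import json
from pathlib import Path


def pick_threshold_split(predictions_file: Path) -> str | None:
    # Read the top-level split names recorded in the predictions JSON file.
    with predictions_file.open() as pf:
        splits = json.load(pf).keys()

    # Prefer "validation"; fall back to "train"; otherwise select no split.
    if "validation" in splits:
        return "validation"
    if "train" in splits:
        return "train"
    return None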