From 18b5d9e621c5bcbac5996fe1a0497975607161ec Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Tue, 28 May 2024 14:21:13 +0200
Subject: [PATCH] [experiment] Check validation split exists for evaluation
 threshold

If no split named "validation" exists, the experiment will use the
"train" split to compute the threshold.
---
 .../libs/classification/scripts/experiment.py   | 17 +++++++++++++++--
 .../libs/segmentation/scripts/experiment.py     | 17 +++++++++++++++--
 2 files changed, 30 insertions(+), 4 deletions(-)

diff --git a/src/mednet/libs/classification/scripts/experiment.py b/src/mednet/libs/classification/scripts/experiment.py
index 5fdf8c6c..f15f4c7d 100644
--- a/src/mednet/libs/classification/scripts/experiment.py
+++ b/src/mednet/libs/classification/scripts/experiment.py
@@ -2,6 +2,7 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
+import json
 from datetime import datetime
 
 import click
@@ -133,11 +134,23 @@ def experiment(
 
     from .evaluate import evaluate
 
+    predictions_file = predictions_output / "predictions.json"
+
+    with (predictions_output / "predictions.json").open() as pf:
+        splits = json.load(pf).keys()
+
+        if "validation" in splits:
+            evaluation_threshold = "validation"
+        elif "train" in splits:
+            evaluation_threshold = "train"
+        else:
+            evaluation_threshold = None
+
     ctx.invoke(
         evaluate,
-        predictions=predictions_output / "predictions.json",
+        predictions=predictions_file,
         output_folder=output_folder,
-        threshold="validation",
+        threshold=evaluation_threshold,
     )
 
     evaluation_stop_timestamp = datetime.now()
diff --git a/src/mednet/libs/segmentation/scripts/experiment.py b/src/mednet/libs/segmentation/scripts/experiment.py
index 0f0db12f..ea74b2dd 100644
--- a/src/mednet/libs/segmentation/scripts/experiment.py
+++ b/src/mednet/libs/segmentation/scripts/experiment.py
@@ -2,6 +2,7 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
+import json
 from datetime import datetime
 
 import click
@@ -138,11 +139,23 @@ def experiment(
 
     evaluation_output = output_folder / "evaluation"
 
+    predictions_file = predictions_output / "predictions.json"
+
+    with (predictions_output / "predictions.json").open() as pf:
+        splits = json.load(pf).keys()
+
+        if "validation" in splits:
+            evaluation_threshold = "validation"
+        elif "train" in splits:
+            evaluation_threshold = "train"
+        else:
+            evaluation_threshold = None
+
     ctx.invoke(
         evaluate,
-        predictions=predictions_output / "predictions.json",
+        predictions=predictions_file,
         output_folder=evaluation_output,
-        threshold=0.5,
+        threshold=evaluation_threshold,
     )
 
     evaluation_stop_timestamp = datetime.now()
-- 
GitLab