From 6874800837a4aa68d2bcd349d01c47ec1fcdd628 Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Mon, 24 Jul 2023 11:17:44 +0200
Subject: [PATCH] Evaluate on all prediction dataloaders

---
 src/ptbench/scripts/evaluate.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/src/ptbench/scripts/evaluate.py b/src/ptbench/scripts/evaluate.py
index c1c9935d..c68ee796 100644
--- a/src/ptbench/scripts/evaluate.py
+++ b/src/ptbench/scripts/evaluate.py
@@ -146,12 +146,13 @@ def evaluate(
 
     from ..engine.evaluator import run
 
-    datamodule.prepare_data()
-    datamodule.setup(stage="test")
-
     datamodule.set_chunk_size(1, 1)
+    datamodule.model_transforms = []
+
+    datamodule.prepare_data()
+    datamodule.setup(stage="predict")
 
-    dataloader = datamodule.test_dataloader()
+    dataloader = datamodule.predict_dataloader()
 
     threshold = _validate_threshold(threshold, dataloader)
 
@@ -176,7 +177,7 @@ def evaluate(
         if k.startswith("_"):
             logger.info(f"Skipping dataset '{k}' (not to be evaluated)")
             continue
-        logger.info(f"Analyzing '{threshold}' set...")
+        logger.info(f"Analyzing '{k}' set...")
         run(
             v,
             k,
-- 
GitLab