From 6874800837a4aa68d2bcd349d01c47ec1fcdd628 Mon Sep 17 00:00:00 2001 From: dcarron <daniel.carron@idiap.ch> Date: Mon, 24 Jul 2023 11:17:44 +0200 Subject: [PATCH] Evaluate on all prediction dataloaders --- src/ptbench/scripts/evaluate.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/ptbench/scripts/evaluate.py b/src/ptbench/scripts/evaluate.py index c1c9935d..c68ee796 100644 --- a/src/ptbench/scripts/evaluate.py +++ b/src/ptbench/scripts/evaluate.py @@ -146,12 +146,13 @@ def evaluate( from ..engine.evaluator import run - datamodule.prepare_data() - datamodule.setup(stage="test") - datamodule.set_chunk_size(1, 1) + datamodule.model_transforms = [] + + datamodule.prepare_data() + datamodule.setup(stage="predict") - dataloader = datamodule.test_dataloader() + dataloader = datamodule.predict_dataloader() threshold = _validate_threshold(threshold, dataloader) @@ -176,7 +177,7 @@ def evaluate( if k.startswith("_"): logger.info(f"Skipping dataset '{k}' (not to be evaluated)") continue - logger.info(f"Analyzing '{threshold}' set...") + logger.info(f"Analyzing '{k}' set...") run( v, k, -- GitLab