diff --git a/src/ptbench/scripts/evaluate.py b/src/ptbench/scripts/evaluate.py index c1c9935d576f4a46dfef9c8aba22bda905157260..c68ee79666eceeb242e5471cba10f256fd1246dd 100644 --- a/src/ptbench/scripts/evaluate.py +++ b/src/ptbench/scripts/evaluate.py @@ -146,12 +146,13 @@ def evaluate( from ..engine.evaluator import run - datamodule.prepare_data() - datamodule.setup(stage="test") - datamodule.set_chunk_size(1, 1) + datamodule.model_transforms = [] + + datamodule.prepare_data() + datamodule.setup(stage="predict") - dataloader = datamodule.test_dataloader() + dataloader = datamodule.predict_dataloader() threshold = _validate_threshold(threshold, dataloader) @@ -176,7 +177,7 @@ def evaluate( if k.startswith("_"): logger.info(f"Skipping dataset '{k}' (not to be evaluated)") continue - logger.info(f"Analyzing '{threshold}' set...") + logger.info(f"Analyzing '{k}' set...") run( v, k,