From d076a17652b46ef0a3988781c6478f63e6120cac Mon Sep 17 00:00:00 2001 From: Andre Anjos <andre.dos.anjos@gmail.com> Date: Fri, 18 Aug 2023 08:57:19 +0200 Subject: [PATCH] [scripts.experiment] Resync with changes to other scripts --- src/ptbench/scripts/experiment.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/ptbench/scripts/experiment.py b/src/ptbench/scripts/experiment.py index 3bf6a50c..44c9a40b 100644 --- a/src/ptbench/scripts/experiment.py +++ b/src/ptbench/scripts/experiment.py @@ -300,15 +300,17 @@ def experiment( if not os.path.exists(model_file): model_file = os.path.join(train_output_folder, "model_final_epoch.ckpt") - predictions_folder = os.path.join(output_folder, "predictions") + predictions_output = os.path.join(output_folder, "predictions.json") ctx.invoke( predict, - output_folder=predictions_folder, + output=predictions_output, model=model, datamodule=datamodule, device=device, weight=model_file, + batch_size=batch_size, + parallel=parallel, ) logger.info("Ended predicting") @@ -322,7 +324,7 @@ def experiment( ctx.invoke( evaluate, output_folder=evaluations_folder, - predictions=os.path.join(predictions_folder, "predictions.json"), + predictions=predictions_output, threshold="validation", ) -- GitLab