From d076a17652b46ef0a3988781c6478f63e6120cac Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Fri, 18 Aug 2023 08:57:19 +0200
Subject: [PATCH] [scripts.experiment] Resync with changes to other scripts

---
 src/ptbench/scripts/experiment.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/src/ptbench/scripts/experiment.py b/src/ptbench/scripts/experiment.py
index 3bf6a50c..44c9a40b 100644
--- a/src/ptbench/scripts/experiment.py
+++ b/src/ptbench/scripts/experiment.py
@@ -300,15 +300,17 @@ def experiment(
     if not os.path.exists(model_file):
         model_file = os.path.join(train_output_folder, "model_final_epoch.ckpt")
 
-    predictions_folder = os.path.join(output_folder, "predictions")
+    predictions_output = os.path.join(output_folder, "predictions.json")
 
     ctx.invoke(
         predict,
-        output_folder=predictions_folder,
+        output=predictions_output,
         model=model,
         datamodule=datamodule,
         device=device,
         weight=model_file,
+        batch_size=batch_size,
+        parallel=parallel,
     )
 
     logger.info("Ended predicting")
@@ -322,7 +324,7 @@ def experiment(
     ctx.invoke(
         evaluate,
         output_folder=evaluations_folder,
-        predictions=os.path.join(predictions_folder, "predictions.json"),
+        predictions=predictions_output,
         threshold="validation",
     )
 
-- 
GitLab