diff --git a/src/ptbench/scripts/experiment.py b/src/ptbench/scripts/experiment.py index 3bf6a50ca4a2104fbed5fafea0fdbb298788cb4b..44c9a40babd2e432381182354cda456d633583e0 100644 --- a/src/ptbench/scripts/experiment.py +++ b/src/ptbench/scripts/experiment.py @@ -300,15 +300,17 @@ def experiment( if not os.path.exists(model_file): model_file = os.path.join(train_output_folder, "model_final_epoch.ckpt") - predictions_folder = os.path.join(output_folder, "predictions") + predictions_output = os.path.join(output_folder, "predictions.json") ctx.invoke( predict, - output_folder=predictions_folder, + output=predictions_output, model=model, datamodule=datamodule, device=device, weight=model_file, + batch_size=batch_size, + parallel=parallel, ) logger.info("Ended predicting") @@ -322,7 +324,7 @@ def experiment( ctx.invoke( evaluate, output_folder=evaluations_folder, - predictions=os.path.join(predictions_folder, "predictions.json"), + predictions=predictions_output, threshold="validation", )