diff --git a/src/ptbench/engine/callbacks.py b/src/ptbench/engine/callbacks.py index c2ce035f73a37fad996181adb5d9412fb8e26556..2da897a7753c474e2be19edff907e13597b012b5 100644 --- a/src/ptbench/engine/callbacks.py +++ b/src/ptbench/engine/callbacks.py @@ -407,8 +407,7 @@ class PredictionsWriter(lightning.pytorch.callbacks.BasePredictionWriter): for dataloader_idx, dataloader_name in enumerate(dataloader_names): logfile = os.path.join( self.output_dir, - f"predictions_{dataloader_name}", - "predictions.csv", + f"{dataloader_name}.csv", ) os.makedirs(os.path.dirname(logfile), exist_ok=True)