diff --git a/src/mednet/scripts/saliency/generate.py b/src/mednet/scripts/saliency/generate.py
index 71f248fa807fdcfb90bc23f9a8e9efb0fd635cf7..649ad96bf15a99126266568b596073c611cf04d7 100644
--- a/src/mednet/scripts/saliency/generate.py
+++ b/src/mednet/scripts/saliency/generate.py
@@ -180,6 +180,7 @@ def generate(
     datamodule.cache_samples = cache_samples
     datamodule.parallel = parallel
     datamodule.model_transforms = model.model_transforms
+    datamodule.batch_size = 1
 
     datamodule.prepare_data()
     datamodule.setup(stage="predict")