diff --git a/src/mednet/scripts/saliency/generate.py b/src/mednet/scripts/saliency/generate.py index 71f248fa807fdcfb90bc23f9a8e9efb0fd635cf7..649ad96bf15a99126266568b596073c611cf04d7 100644 --- a/src/mednet/scripts/saliency/generate.py +++ b/src/mednet/scripts/saliency/generate.py @@ -180,6 +180,7 @@ def generate( datamodule.cache_samples = cache_samples datamodule.parallel = parallel datamodule.model_transforms = model.model_transforms + datamodule.batch_size = 1 datamodule.prepare_data() datamodule.setup(stage="predict")