From 944bba9822207e72fbb39d9ecb663ae97937cb8b Mon Sep 17 00:00:00 2001 From: dcarron <daniel.carron@idiap.ch> Date: Thu, 6 Jun 2024 11:06:30 +0200 Subject: [PATCH] [classification.saliency] Fix wrong batch size during generation --- .../libs/classification/scripts/saliency/completeness.py | 7 ++----- .../libs/classification/scripts/saliency/generate.py | 7 ++----- .../classification/scripts/saliency/interpretability.py | 1 + 3 files changed, 5 insertions(+), 10 deletions(-) diff --git a/src/mednet/libs/classification/scripts/saliency/completeness.py b/src/mednet/libs/classification/scripts/saliency/completeness.py index fa2d243d..3ee81288 100644 --- a/src/mednet/libs/classification/scripts/saliency/completeness.py +++ b/src/mednet/libs/classification/scripts/saliency/completeness.py @@ -204,6 +204,7 @@ def completeness( import json from mednet.libs.common.engine.device import DeviceManager + from mednet.libs.common.scripts.predict import setup_datamodule from mednet.libs.common.utils.checkpointer import ( get_checkpoint_to_run_inference, ) @@ -222,11 +223,7 @@ def completeness( device_manager = DeviceManager(device) datamodule.cache_samples = cache_samples - datamodule.parallel = parallel - datamodule.model_transforms = model.model_transforms - - datamodule.prepare_data() - datamodule.setup(stage="predict") + setup_datamodule(datamodule, model, 1, parallel) if weight.is_dir(): weight = get_checkpoint_to_run_inference(weight) diff --git a/src/mednet/libs/classification/scripts/saliency/generate.py b/src/mednet/libs/classification/scripts/saliency/generate.py index 3dfd8d84..825c7b3e 100644 --- a/src/mednet/libs/classification/scripts/saliency/generate.py +++ b/src/mednet/libs/classification/scripts/saliency/generate.py @@ -170,6 +170,7 @@ def generate( """ from mednet.libs.common.engine.device import DeviceManager + from mednet.libs.common.scripts.predict import setup_datamodule from mednet.libs.common.utils.checkpointer import ( get_checkpoint_to_run_inference, ) @@ -182,12 +183,8 @@ def generate( device_manager = DeviceManager(device) datamodule.cache_samples = cache_samples - datamodule.parallel = parallel datamodule.model_transforms = model.model_transforms - datamodule.batch_size = 1 - - datamodule.prepare_data() - datamodule.setup(stage="predict") + setup_datamodule(datamodule, model, 1, parallel) if weight.is_dir(): weight = get_checkpoint_to_run_inference(weight) diff --git a/src/mednet/libs/classification/scripts/saliency/interpretability.py b/src/mednet/libs/classification/scripts/saliency/interpretability.py index c4b2ba52..0d0f556b 100644 --- a/src/mednet/libs/classification/scripts/saliency/interpretability.py +++ b/src/mednet/libs/classification/scripts/saliency/interpretability.py @@ -123,6 +123,7 @@ def interpretability( from ...engine.saliency.interpretability import run + datamodule.batch_size = 1 datamodule.model_transforms = model.transforms datamodule.prepare_data() datamodule.setup(stage="predict") -- GitLab