diff --git a/src/mednet/libs/classification/scripts/saliency/completeness.py b/src/mednet/libs/classification/scripts/saliency/completeness.py index fa2d243d6287df127f6da33cb0a31f7ed2b405f6..3ee81288101676fff4714adcfb08d1f41c20c378 100644 --- a/src/mednet/libs/classification/scripts/saliency/completeness.py +++ b/src/mednet/libs/classification/scripts/saliency/completeness.py @@ -204,6 +204,7 @@ def completeness( import json from mednet.libs.common.engine.device import DeviceManager + from mednet.libs.common.scripts.predict import setup_datamodule from mednet.libs.common.utils.checkpointer import ( get_checkpoint_to_run_inference, ) @@ -222,11 +223,7 @@ def completeness( device_manager = DeviceManager(device) datamodule.cache_samples = cache_samples - datamodule.parallel = parallel - datamodule.model_transforms = model.model_transforms - - datamodule.prepare_data() - datamodule.setup(stage="predict") + setup_datamodule(datamodule, model, 1, parallel) if weight.is_dir(): weight = get_checkpoint_to_run_inference(weight) diff --git a/src/mednet/libs/classification/scripts/saliency/generate.py b/src/mednet/libs/classification/scripts/saliency/generate.py index 3dfd8d84d29c3142f9081e4d6ffe56a38ffa9726..825c7b3e2d22cf8b487143606420e1c050c71454 100644 --- a/src/mednet/libs/classification/scripts/saliency/generate.py +++ b/src/mednet/libs/classification/scripts/saliency/generate.py @@ -170,6 +170,7 @@ def generate( """ from mednet.libs.common.engine.device import DeviceManager + from mednet.libs.common.scripts.predict import setup_datamodule from mednet.libs.common.utils.checkpointer import ( get_checkpoint_to_run_inference, ) @@ -182,12 +183,8 @@ def generate( device_manager = DeviceManager(device) datamodule.cache_samples = cache_samples - datamodule.parallel = parallel datamodule.model_transforms = model.model_transforms - datamodule.batch_size = 1 - - datamodule.prepare_data() - datamodule.setup(stage="predict") + setup_datamodule(datamodule, model, 1, parallel) if weight.is_dir(): weight = get_checkpoint_to_run_inference(weight) diff --git a/src/mednet/libs/classification/scripts/saliency/interpretability.py b/src/mednet/libs/classification/scripts/saliency/interpretability.py index c4b2ba525ed720ac1c8bc4fc2651cc2c7e6fb3d4..0d0f556b49e6b6dd8f0a4c9d68adbfa9d5003192 100644 --- a/src/mednet/libs/classification/scripts/saliency/interpretability.py +++ b/src/mednet/libs/classification/scripts/saliency/interpretability.py @@ -123,6 +123,7 @@ def interpretability( from ...engine.saliency.interpretability import run + datamodule.batch_size = 1 datamodule.model_transforms = model.transforms datamodule.prepare_data() datamodule.setup(stage="predict")