Skip to content
Snippets Groups Projects
Commit 944bba98 authored by Daniel CARRON's avatar Daniel CARRON :b: Committed by André Anjos
Browse files

[classification.saliency] Fix wrong batch size during generation

parent 53d2415a
No related branches found
No related tags found
1 merge request!46Create common library
...@@ -204,6 +204,7 @@ def completeness( ...@@ -204,6 +204,7 @@ def completeness(
import json import json
from mednet.libs.common.engine.device import DeviceManager from mednet.libs.common.engine.device import DeviceManager
from mednet.libs.common.scripts.predict import setup_datamodule
from mednet.libs.common.utils.checkpointer import ( from mednet.libs.common.utils.checkpointer import (
get_checkpoint_to_run_inference, get_checkpoint_to_run_inference,
) )
...@@ -222,11 +223,7 @@ def completeness( ...@@ -222,11 +223,7 @@ def completeness(
device_manager = DeviceManager(device) device_manager = DeviceManager(device)
datamodule.cache_samples = cache_samples datamodule.cache_samples = cache_samples
datamodule.parallel = parallel setup_datamodule(datamodule, model, 1, parallel)
datamodule.model_transforms = model.model_transforms
datamodule.prepare_data()
datamodule.setup(stage="predict")
if weight.is_dir(): if weight.is_dir():
weight = get_checkpoint_to_run_inference(weight) weight = get_checkpoint_to_run_inference(weight)
......
...@@ -170,6 +170,7 @@ def generate( ...@@ -170,6 +170,7 @@ def generate(
""" """
from mednet.libs.common.engine.device import DeviceManager from mednet.libs.common.engine.device import DeviceManager
from mednet.libs.common.scripts.predict import setup_datamodule
from mednet.libs.common.utils.checkpointer import ( from mednet.libs.common.utils.checkpointer import (
get_checkpoint_to_run_inference, get_checkpoint_to_run_inference,
) )
...@@ -182,12 +183,8 @@ def generate( ...@@ -182,12 +183,8 @@ def generate(
device_manager = DeviceManager(device) device_manager = DeviceManager(device)
datamodule.cache_samples = cache_samples datamodule.cache_samples = cache_samples
datamodule.parallel = parallel
datamodule.model_transforms = model.model_transforms datamodule.model_transforms = model.model_transforms
datamodule.batch_size = 1 setup_datamodule(datamodule, model, 1, parallel)
datamodule.prepare_data()
datamodule.setup(stage="predict")
if weight.is_dir(): if weight.is_dir():
weight = get_checkpoint_to_run_inference(weight) weight = get_checkpoint_to_run_inference(weight)
......
...@@ -123,6 +123,7 @@ def interpretability( ...@@ -123,6 +123,7 @@ def interpretability(
from ...engine.saliency.interpretability import run from ...engine.saliency.interpretability import run
datamodule.batch_size = 1
datamodule.model_transforms = model.transforms datamodule.model_transforms = model.transforms
datamodule.prepare_data() datamodule.prepare_data()
datamodule.setup(stage="predict") datamodule.setup(stage="predict")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment