Skip to content
Snippets Groups Projects
Commit 0a858361 authored by Daniel CARRON's avatar Daniel CARRON :b:
Browse files

Updated accelerator selection during prediction

parent 55f4eac1
No related branches found
No related tags found
Loading
Pipeline #73168 failed
...@@ -7,12 +7,13 @@ import os ...@@ -7,12 +7,13 @@ import os
from pytorch_lightning import Trainer from pytorch_lightning import Trainer
from ..utils.accelerator import AcceleratorProcessor
from .callbacks import PredictionsWriter from .callbacks import PredictionsWriter
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def run(model, data_loader, name, device, output_folder, grad_cams=False): def run(model, data_loader, name, accelerator, output_folder, grad_cams=False):
"""Runs inference on input data, outputs HDF5 files with predictions. """Runs inference on input data, outputs HDF5 files with predictions.
Parameters Parameters
...@@ -26,8 +27,8 @@ def run(model, data_loader, name, device, output_folder, grad_cams=False): ...@@ -26,8 +27,8 @@ def run(model, data_loader, name, device, output_folder, grad_cams=False):
the local name of this dataset (e.g. ``train``, or ``test``), to be the local name of this dataset (e.g. ``train``, or ``test``), to be
used when saving measures files. used when saving measures files.
device : str accelerator : str
device to use ``cpu`` or ``cuda:0`` accelerator to use
output_folder : str output_folder : str
folder where to store output prediction and model folder where to store output prediction and model
...@@ -48,14 +49,21 @@ def run(model, data_loader, name, device, output_folder, grad_cams=False): ...@@ -48,14 +49,21 @@ def run(model, data_loader, name, device, output_folder, grad_cams=False):
logger.info(f"Output folder: {output_folder}") logger.info(f"Output folder: {output_folder}")
os.makedirs(output_folder, exist_ok=True) os.makedirs(output_folder, exist_ok=True)
logger.info(f"Device: {device}") accelerator_processor = AcceleratorProcessor(accelerator)
if accelerator_processor.device is None:
devices = "auto"
else:
devices = accelerator_processor.device
logger.info(f"Device: {devices}")
logfile_name = os.path.join(output_folder, "predictions.csv") logfile_name = os.path.join(output_folder, "predictions.csv")
logfile_fields = ("filename", "likelihood", "ground_truth") logfile_fields = ("filename", "likelihood", "ground_truth")
trainer = Trainer( trainer = Trainer(
accelerator="auto", accelerator=accelerator_processor.accelerator,
devices="auto", devices=devices,
callbacks=[ callbacks=[
PredictionsWriter( PredictionsWriter(
logfile_name=logfile_name, logfile_name=logfile_name,
......
...@@ -62,9 +62,9 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") ...@@ -62,9 +62,9 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
cls=ResourceOption, cls=ResourceOption,
) )
@click.option( @click.option(
"--device", "--accelerator",
"-d", "-a",
help='A string indicating the device to use (e.g. "cpu" or "cuda:0")', help='A string indicating the accelerator to use (e.g. "auto", "cpu" or "gpu"). If auto, will select the best one available',
show_default=True, show_default=True,
required=True, required=True,
default="cpu", default="cpu",
...@@ -98,7 +98,7 @@ def predict( ...@@ -98,7 +98,7 @@ def predict(
model, model,
dataset, dataset,
batch_size, batch_size,
device, accelerator,
weight, weight,
relevance_analysis, relevance_analysis,
grad_cams, grad_cams,
...@@ -154,7 +154,7 @@ def predict( ...@@ -154,7 +154,7 @@ def predict(
pin_memory=torch.cuda.is_available(), pin_memory=torch.cuda.is_available(),
) )
predictions = run( predictions = run(
model, data_loader, k, device, output_folder, grad_cams model, data_loader, k, accelerator, output_folder, grad_cams
) )
# Relevance analysis using permutation feature importance # Relevance analysis using permutation feature importance
...@@ -189,7 +189,11 @@ def predict( ...@@ -189,7 +189,11 @@ def predict(
) )
predictions_with_mean = run( predictions_with_mean = run(
model, data_loader, k, device, output_folder + "_temp" model,
data_loader,
k,
accelerator,
output_folder + "_temp",
) )
# Compute MSE between original and new predictions # Compute MSE between original and new predictions
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment