Skip to content
Snippets Groups Projects
Commit 0a858361 authored by Daniel CARRON's avatar Daniel CARRON :b:
Browse files

Updated accelerator selection during prediction

parent 55f4eac1
No related branches found
No related tags found
1 merge request!4Moved code to lightning
Pipeline #73168 failed
......@@ -7,12 +7,13 @@ import os
from pytorch_lightning import Trainer
from ..utils.accelerator import AcceleratorProcessor
from .callbacks import PredictionsWriter
logger = logging.getLogger(__name__)
def run(model, data_loader, name, device, output_folder, grad_cams=False):
def run(model, data_loader, name, accelerator, output_folder, grad_cams=False):
"""Runs inference on input data, outputs HDF5 files with predictions.
Parameters
......@@ -26,8 +27,8 @@ def run(model, data_loader, name, device, output_folder, grad_cams=False):
the local name of this dataset (e.g. ``train``, or ``test``), to be
used when saving measures files.
device : str
device to use ``cpu`` or ``cuda:0``
accelerator : str
accelerator to use
output_folder : str
folder where to store output prediction and model
......@@ -48,14 +49,21 @@ def run(model, data_loader, name, device, output_folder, grad_cams=False):
logger.info(f"Output folder: {output_folder}")
os.makedirs(output_folder, exist_ok=True)
logger.info(f"Device: {device}")
accelerator_processor = AcceleratorProcessor(accelerator)
if accelerator_processor.device is None:
devices = "auto"
else:
devices = accelerator_processor.device
logger.info(f"Device: {devices}")
logfile_name = os.path.join(output_folder, "predictions.csv")
logfile_fields = ("filename", "likelihood", "ground_truth")
trainer = Trainer(
accelerator="auto",
devices="auto",
accelerator=accelerator_processor.accelerator,
devices=devices,
callbacks=[
PredictionsWriter(
logfile_name=logfile_name,
......
......@@ -62,9 +62,9 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
cls=ResourceOption,
)
@click.option(
"--device",
"-d",
help='A string indicating the device to use (e.g. "cpu" or "cuda:0")',
"--accelerator",
"-a",
help='A string indicating the accelerator to use (e.g. "auto", "cpu" or "gpu"). If auto, will select the best one available',
show_default=True,
required=True,
default="cpu",
......@@ -98,7 +98,7 @@ def predict(
model,
dataset,
batch_size,
device,
accelerator,
weight,
relevance_analysis,
grad_cams,
......@@ -154,7 +154,7 @@ def predict(
pin_memory=torch.cuda.is_available(),
)
predictions = run(
model, data_loader, k, device, output_folder, grad_cams
model, data_loader, k, accelerator, output_folder, grad_cams
)
# Relevance analysis using permutation feature importance
......@@ -189,7 +189,11 @@ def predict(
)
predictions_with_mean = run(
model, data_loader, k, device, output_folder + "_temp"
model,
data_loader,
k,
accelerator,
output_folder + "_temp",
)
# Compute MSE between original and new predictions
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment