From 0a8583614cd91b7a11d8ec7d6dc87f4655a50e48 Mon Sep 17 00:00:00 2001 From: dcarron <daniel.carron@idiap.ch> Date: Mon, 8 May 2023 17:17:16 +0200 Subject: [PATCH] Updated accelerator selection during prediction --- src/ptbench/engine/predictor.py | 20 ++++++++++++++------ src/ptbench/scripts/predict.py | 16 ++++++++++------ 2 files changed, 24 insertions(+), 12 deletions(-) diff --git a/src/ptbench/engine/predictor.py b/src/ptbench/engine/predictor.py index 31c3e4c6..8afc8b85 100644 --- a/src/ptbench/engine/predictor.py +++ b/src/ptbench/engine/predictor.py @@ -7,12 +7,13 @@ import os from pytorch_lightning import Trainer +from ..utils.accelerator import AcceleratorProcessor from .callbacks import PredictionsWriter logger = logging.getLogger(__name__) -def run(model, data_loader, name, device, output_folder, grad_cams=False): +def run(model, data_loader, name, accelerator, output_folder, grad_cams=False): """Runs inference on input data, outputs HDF5 files with predictions. Parameters @@ -26,8 +27,8 @@ def run(model, data_loader, name, device, output_folder, grad_cams=False): the local name of this dataset (e.g. ``train``, or ``test``), to be used when saving measures files. - device : str - device to use ``cpu`` or ``cuda:0`` + accelerator : str + accelerator to use output_folder : str folder where to store output prediction and model @@ -48,14 +49,21 @@ def run(model, data_loader, name, device, output_folder, grad_cams=False): logger.info(f"Output folder: {output_folder}") os.makedirs(output_folder, exist_ok=True) - logger.info(f"Device: {device}") + accelerator_processor = AcceleratorProcessor(accelerator) + + if accelerator_processor.device is None: + devices = "auto" + else: + devices = accelerator_processor.device + + logger.info(f"Device: {devices}") logfile_name = os.path.join(output_folder, "predictions.csv") logfile_fields = ("filename", "likelihood", "ground_truth") trainer = Trainer( - accelerator="auto", - devices="auto", + accelerator=accelerator_processor.accelerator, + devices=devices, callbacks=[ PredictionsWriter( logfile_name=logfile_name, diff --git a/src/ptbench/scripts/predict.py b/src/ptbench/scripts/predict.py index 65336ac1..689bca1b 100644 --- a/src/ptbench/scripts/predict.py +++ b/src/ptbench/scripts/predict.py @@ -62,9 +62,9 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") cls=ResourceOption, ) @click.option( - "--device", - "-d", - help='A string indicating the device to use (e.g. "cpu" or "cuda:0")', + "--accelerator", + "-a", + help='A string indicating the accelerator to use (e.g. "auto", "cpu" or "gpu"). If auto, will select the best one available', show_default=True, required=True, default="cpu", @@ -98,7 +98,7 @@ def predict( model, dataset, batch_size, - device, + accelerator, weight, relevance_analysis, grad_cams, @@ -154,7 +154,7 @@ def predict( pin_memory=torch.cuda.is_available(), ) predictions = run( - model, data_loader, k, device, output_folder, grad_cams + model, data_loader, k, accelerator, output_folder, grad_cams ) # Relevance analysis using permutation feature importance @@ -189,7 +189,11 @@ def predict( ) predictions_with_mean = run( - model, data_loader, k, device, output_folder + "_temp" + model, + data_loader, + k, + accelerator, + output_folder + "_temp", ) # Compute MSE between original and new predictions -- GitLab