From 0a8583614cd91b7a11d8ec7d6dc87f4655a50e48 Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Mon, 8 May 2023 17:17:16 +0200
Subject: [PATCH] Updated accelerator selection during prediction

---
 src/ptbench/engine/predictor.py | 20 ++++++++++++++------
 src/ptbench/scripts/predict.py  | 16 ++++++++++------
 2 files changed, 24 insertions(+), 12 deletions(-)

diff --git a/src/ptbench/engine/predictor.py b/src/ptbench/engine/predictor.py
index 31c3e4c6..8afc8b85 100644
--- a/src/ptbench/engine/predictor.py
+++ b/src/ptbench/engine/predictor.py
@@ -7,12 +7,13 @@ import os
 
 from pytorch_lightning import Trainer
 
+from ..utils.accelerator import AcceleratorProcessor
 from .callbacks import PredictionsWriter
 
 logger = logging.getLogger(__name__)
 
 
-def run(model, data_loader, name, device, output_folder, grad_cams=False):
+def run(model, data_loader, name, accelerator, output_folder, grad_cams=False):
     """Runs inference on input data, outputs HDF5 files with predictions.
 
     Parameters
@@ -26,8 +27,8 @@ def run(model, data_loader, name, device, output_folder, grad_cams=False):
         the local name of this dataset (e.g. ``train``, or ``test``), to be
         used when saving measures files.
 
-    device : str
-        device to use ``cpu`` or ``cuda:0``
+    accelerator : str
+        accelerator to use
 
     output_folder : str
         folder where to store output prediction and model
@@ -48,14 +49,21 @@ def run(model, data_loader, name, device, output_folder, grad_cams=False):
     logger.info(f"Output folder: {output_folder}")
     os.makedirs(output_folder, exist_ok=True)
 
-    logger.info(f"Device: {device}")
+    accelerator_processor = AcceleratorProcessor(accelerator)
+
+    if accelerator_processor.device is None:
+        devices = "auto"
+    else:
+        devices = accelerator_processor.device
+
+    logger.info(f"Device: {devices}")
 
     logfile_name = os.path.join(output_folder, "predictions.csv")
     logfile_fields = ("filename", "likelihood", "ground_truth")
 
     trainer = Trainer(
-        accelerator="auto",
-        devices="auto",
+        accelerator=accelerator_processor.accelerator,
+        devices=devices,
         callbacks=[
             PredictionsWriter(
                 logfile_name=logfile_name,
diff --git a/src/ptbench/scripts/predict.py b/src/ptbench/scripts/predict.py
index 65336ac1..689bca1b 100644
--- a/src/ptbench/scripts/predict.py
+++ b/src/ptbench/scripts/predict.py
@@ -62,9 +62,9 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
     cls=ResourceOption,
 )
 @click.option(
-    "--device",
-    "-d",
-    help='A string indicating the device to use (e.g. "cpu" or "cuda:0")',
+    "--accelerator",
+    "-a",
+    help='A string indicating the accelerator to use (e.g. "auto", "cpu" or "gpu"). If auto, will select the best one available',
     show_default=True,
     required=True,
     default="cpu",
@@ -98,7 +98,7 @@ def predict(
     model,
     dataset,
     batch_size,
-    device,
+    accelerator,
     weight,
     relevance_analysis,
     grad_cams,
@@ -154,7 +154,7 @@ def predict(
             pin_memory=torch.cuda.is_available(),
         )
         predictions = run(
-            model, data_loader, k, device, output_folder, grad_cams
+            model, data_loader, k, accelerator, output_folder, grad_cams
         )
 
         # Relevance analysis using permutation feature importance
@@ -189,7 +189,11 @@ def predict(
                     )
 
                     predictions_with_mean = run(
-                        model, data_loader, k, device, output_folder + "_temp"
+                        model,
+                        data_loader,
+                        k,
+                        accelerator,
+                        output_folder + "_temp",
                     )
 
                     # Compute MSE between original and new predictions
-- 
GitLab