diff --git a/src/ptbench/engine/predictor.py b/src/ptbench/engine/predictor.py
index 31c3e4c67fc355a6fd077cebbe8e15bc9b7b4911..8afc8b8532eb7920bebfa554f05ba8e6bf8faf23 100644
--- a/src/ptbench/engine/predictor.py
+++ b/src/ptbench/engine/predictor.py
@@ -7,12 +7,13 @@ import os
 
 from pytorch_lightning import Trainer
 
+from ..utils.accelerator import AcceleratorProcessor
 from .callbacks import PredictionsWriter
 
 logger = logging.getLogger(__name__)
 
 
-def run(model, data_loader, name, device, output_folder, grad_cams=False):
+def run(model, data_loader, name, accelerator, output_folder, grad_cams=False):
     """Runs inference on input data, outputs HDF5 files with predictions.
 
     Parameters
@@ -26,8 +27,8 @@ def run(model, data_loader, name, device, output_folder, grad_cams=False):
         the local name of this dataset (e.g. ``train``, or ``test``), to be
         used when saving measures files.
 
-    device : str
-        device to use ``cpu`` or ``cuda:0``
+    accelerator : str
+        accelerator to use
 
     output_folder : str
         folder where to store output prediction and model
@@ -48,14 +49,21 @@ def run(model, data_loader, name, device, output_folder, grad_cams=False):
     logger.info(f"Output folder: {output_folder}")
     os.makedirs(output_folder, exist_ok=True)
 
-    logger.info(f"Device: {device}")
+    accelerator_processor = AcceleratorProcessor(accelerator)
+
+    if accelerator_processor.device is None:
+        devices = "auto"
+    else:
+        devices = accelerator_processor.device
+
+    logger.info(f"Device: {devices}")
 
     logfile_name = os.path.join(output_folder, "predictions.csv")
     logfile_fields = ("filename", "likelihood", "ground_truth")
 
     trainer = Trainer(
-        accelerator="auto",
-        devices="auto",
+        accelerator=accelerator_processor.accelerator,
+        devices=devices,
         callbacks=[
             PredictionsWriter(
                 logfile_name=logfile_name,
diff --git a/src/ptbench/scripts/predict.py b/src/ptbench/scripts/predict.py
index 65336ac138bd58cf30f9879c562c32a5614591e1..689bca1bc7a201b9b5855716f6d19120e4630437 100644
--- a/src/ptbench/scripts/predict.py
+++ b/src/ptbench/scripts/predict.py
@@ -62,9 +62,9 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
     cls=ResourceOption,
 )
 @click.option(
-    "--device",
-    "-d",
-    help='A string indicating the device to use (e.g. "cpu" or "cuda:0")',
+    "--accelerator",
+    "-a",
+    help='A string indicating the accelerator to use (e.g. "auto", "cpu" or "gpu"). If auto, will select the best one available',
     show_default=True,
     required=True,
     default="cpu",
@@ -98,7 +98,7 @@ def predict(
     model,
     dataset,
     batch_size,
-    device,
+    accelerator,
     weight,
     relevance_analysis,
     grad_cams,
@@ -154,7 +154,7 @@ def predict(
             pin_memory=torch.cuda.is_available(),
         )
         predictions = run(
-            model, data_loader, k, device, output_folder, grad_cams
+            model, data_loader, k, accelerator, output_folder, grad_cams
         )
 
         # Relevance analysis using permutation feature importance
@@ -189,7 +189,11 @@ def predict(
                 )
 
                 predictions_with_mean = run(
-                    model, data_loader, k, device, output_folder + "_temp"
+                    model,
+                    data_loader,
+                    k,
+                    accelerator,
+                    output_folder + "_temp",
                 )
 
                 # Compute MSE between original and new predictions
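
Note (not part of the patch): the predictor now delegates device selection to AcceleratorProcessor, whose implementation in src/ptbench/utils/accelerator.py is not shown in this diff. Below is a minimal sketch of the interface the new code relies on, assuming the class only maps a user-facing string such as "cpu", "gpu" or "cuda:0" onto the accelerator/devices pair consumed by pytorch_lightning.Trainer; the attribute names (.accelerator, .device) are taken from how run() uses the object above, while the "cuda:N" parsing is a hypothetical illustration.

    # Illustrative sketch only -- not the actual ptbench implementation.
    class AcceleratorProcessor:
        """Maps a user-facing accelerator string onto the ``accelerator``
        and ``devices`` arguments expected by ``pytorch_lightning.Trainer``."""

        def __init__(self, name):
            if name.startswith("cuda"):
                # e.g. "cuda:1" -> accelerator="gpu", device=[1]
                self.accelerator = "gpu"
                _, _, index = name.partition(":")
                self.device = [int(index)] if index else None
            elif name in ("auto", "cpu", "gpu"):
                # Passed through as-is; no explicit device index, so callers
                # fall back to devices="auto" (see run() above).
                self.accelerator = name
                self.device = None
            else:
                raise ValueError(f"unsupported accelerator: {name}")

With this split, run() no longer needs to understand CUDA device strings: when no explicit device index is available, it passes devices="auto" and lets Lightning pick the device for the chosen accelerator.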