# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

import logging
import os

from lightning.pytorch import Trainer

from ..utils.accelerator import AcceleratorProcessor
from .callbacks import PredictionsWriter

logger = logging.getLogger(__name__)


def run(model, data_loader, name, accelerator, output_folder, grad_cams=False):
    """Run inference on input data and write the predictions to a CSV file.

    Predictions are produced with a :py:class:`lightning.pytorch.Trainer`
    and written by a ``PredictionsWriter`` callback to
    ``<output_folder>/<name>/predictions.csv``, with the columns
    ``filename``, ``likelihood`` and ``ground_truth``.

    Parameters
    ----------
    model : :py:class:`torch.nn.Module`
        Neural network model (e.g. pasa).

    data_loader : :py:class:`torch.utils.data.DataLoader`
        The pytorch DataLoader used to iterate over batches.

    name : str
        The local name of this dataset (e.g. ``train``, or ``test``), used
        as the sub-directory of ``output_folder`` in which results are saved.

    accelerator : str
        A string indicating the accelerator to use (e.g. "cpu" or "gpu").
        The device can also be specified (e.g. "gpu:0").

    output_folder : str
        Directory in which the results will be saved. It is created if it
        does not exist.

    grad_cams : bool
        Intended to request Grad-CAM export for every prediction (to be used
        with a batch size of 1 and the DensenetRS model). This function does
        not implement that export. The flag is accepted only for interface
        compatibility, and a warning is logged when it is ``True``.

    Returns
    -------
    all_predictions : list
        The value returned by :py:meth:`lightning.pytorch.Trainer.predict`,
        typically one entry per predicted batch. Presumably each entry pairs
        filenames with likelihoods and ground truth, matching the CSV columns
        above; confirm against the model's ``predict_step``.
    """
    if grad_cams:
        # The flag was previously ignored without notice; make that visible.
        logger.warning(
            "grad_cams=True was requested, but Grad-CAM export is not "
            "implemented in this prediction routine; the flag is ignored."
        )

    # Results for each dataset split go into their own sub-directory.
    output_folder = os.path.join(output_folder, name)
    logger.info("Output folder: %s", output_folder)
    os.makedirs(output_folder, exist_ok=True)

    accelerator_processor = AcceleratorProcessor(accelerator)
    # NOTE(review): presumably ``device`` is None when the accelerator string
    # names no explicit device index. In that case Lightning chooses the
    # devices itself ("auto").
    if accelerator_processor.device is None:
        devices = "auto"
    else:
        devices = accelerator_processor.device
    logger.info("Device: %s", devices)

    logfile_name = os.path.join(output_folder, "predictions.csv")
    logfile_fields = ("filename", "likelihood", "ground_truth")

    trainer = Trainer(
        accelerator=accelerator_processor.accelerator,
        devices=devices,
        callbacks=[
            # Writes the collected predictions once, at the end of the
            # prediction epoch, rather than after every batch.
            PredictionsWriter(
                logfile_name=logfile_name,
                logfile_fields=logfile_fields,
                write_interval="epoch",
            ),
        ],
    )

    all_predictions = trainer.predict(model, data_loader)

    return all_predictions