-
Daniel CARRON authored
predictor.py 2.22 KiB
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
import logging
import os
from lightning.pytorch import Trainer
from ..utils.accelerator import AcceleratorProcessor
from .callbacks import PredictionsWriter
logger = logging.getLogger(__name__)
def run(model, data_loader, name, accelerator, output_folder, grad_cams=False):
    """Runs inference on input data, outputs csv files with predictions.

    A ``PredictionsWriter`` callback is attached to a lightning
    :py:class:`lightning.pytorch.Trainer`.  It is configured to write
    ``<output_folder>/<name>/predictions.csv`` with the columns
    ``filename``, ``likelihood`` and ``ground_truth``.

    Parameters
    ----------
    model : :py:class:`lightning.pytorch.LightningModule`
        Neural network model (e.g. pasa).  ``Trainer.predict`` requires a
        ``LightningModule``; a plain :py:class:`torch.nn.Module` is rejected.

    data_loader : :py:class:`torch.utils.data.DataLoader`
        The pytorch Dataloader used to iterate over batches.

    name : str
        The local name of this dataset (e.g. ``train``, or ``test``), to be
        used when saving measures files.  A sub-directory with this name is
        created under ``output_folder``.

    accelerator : str
        A string indicating the accelerator to use (e.g. "cpu" or "gpu").
        The device can also be specified (gpu:0).

    output_folder : str
        Directory in which the results will be saved.

    grad_cams : bool
        If we export grad cams for every prediction (must be used along
        a batch size of 1 with the DensenetRS model).  NOTE: this function
        does not currently act on this flag; a warning is logged when it is
        set.

    Returns
    -------
    all_predictions : list
        Whatever :py:meth:`lightning.pytorch.Trainer.predict` returns for
        ``model`` over ``data_loader``.  Presumably each entry associates
        filename, likelihood and ground truth, but the exact content depends
        on the model's ``predict_step``.  TODO confirm.
    """
    output_folder = os.path.join(output_folder, name)
    logger.info("Output folder: %s", output_folder)
    os.makedirs(output_folder, exist_ok=True)

    if grad_cams:
        # The flag is part of the public signature but nothing below consumes
        # it; make that visible to the caller instead of ignoring it silently.
        logger.warning(
            "grad_cams=True was requested, but Grad-CAM export is not "
            "performed by this prediction routine; ignoring."
        )

    accelerator_processor = AcceleratorProcessor(accelerator)
    # Fall back to lightning's automatic device selection when the processor
    # did not resolve an explicit device.
    devices = (
        "auto"
        if accelerator_processor.device is None
        else accelerator_processor.device
    )
    logger.info("Device: %s", devices)

    logfile_name = os.path.join(output_folder, "predictions.csv")
    logfile_fields = ("filename", "likelihood", "ground_truth")

    trainer = Trainer(
        accelerator=accelerator_processor.accelerator,
        devices=devices,
        callbacks=[
            # Configured to write its CSV once per prediction epoch rather
            # than per batch.
            PredictionsWriter(
                logfile_name=logfile_name,
                logfile_fields=logfile_fields,
                write_interval="epoch",
            ),
        ],
    )

    return trainer.predict(model, dataloaders=data_loader)