# predictor.py
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

import logging
import os

from lightning.pytorch import Trainer

from ..utils.accelerator import AcceleratorProcessor
from .callbacks import PredictionsWriter

logger = logging.getLogger(__name__)


def run(model, data_loader, name, accelerator, output_folder, grad_cams=False):
    """Runs inference on input data, outputs csv files with predictions.

    Parameters
    ----------
    model : :py:class:`torch.nn.Module`
        Neural network model (e.g. pasa).

    data_loader : :py:class:`torch.utils.data.DataLoader`
        The pytorch Dataloader used to iterate over batches.

    name : str
        The local name of this dataset (e.g. ``train``, or ``test``), to be
        used when saving measures files.

    accelerator : str
        A string indicating the accelerator to use (e.g. "cpu" or "gpu"). The
        device can also be specified (gpu:0).

    output_folder : str
        Directory in which the results will be saved.

    grad_cams : bool
        If we export grad cams for every prediction (must be used along
        a batch size of 1 with the DensenetRS model).

        .. note:: NOTE(review): this flag is currently accepted but never
           read anywhere in this function — confirm whether grad-cam export
           was meant to be wired into the ``Trainer``/callbacks here.

    Returns
    -------

    all_predictions : list
        All the predictions associated with filename and ground truth.
    """
    # Predictions for each dataset split go in their own sub-directory.
    output_folder = os.path.join(output_folder, name)

    logger.info("Output folder: %s", output_folder)
    os.makedirs(output_folder, exist_ok=True)

    # Translate the user-supplied accelerator string (e.g. "gpu:0") into the
    # accelerator/device pair that lightning's Trainer expects.
    accelerator_processor = AcceleratorProcessor(accelerator)

    # When no explicit device was requested, let lightning pick one.
    if accelerator_processor.device is None:
        devices = "auto"
    else:
        devices = accelerator_processor.device

    logger.info("Device: %s", devices)

    logfile_name = os.path.join(output_folder, "predictions.csv")
    logfile_fields = ("filename", "likelihood", "ground_truth")

    # The PredictionsWriter callback persists predictions to the csv file
    # once per epoch, so trainer.predict both returns and saves the results.
    trainer = Trainer(
        accelerator=accelerator_processor.accelerator,
        devices=devices,
        callbacks=[
            PredictionsWriter(
                logfile_name=logfile_name,
                logfile_fields=logfile_fields,
                write_interval="epoch",
            ),
        ],
    )

    all_predictions = trainer.predict(model, data_loader)

    return all_predictions