Skip to content
Snippets Groups Projects

Moved code to lightning

Merged Daniel CARRON requested to merge move-to-lightning into main
1 unresolved thread
Files
2
@@ -7,12 +7,13 @@ import os
from pytorch_lightning import Trainer
from ..utils.accelerator import AcceleratorProcessor
from .callbacks import PredictionsWriter
logger = logging.getLogger(__name__)
def run(model, data_loader, name, device, output_folder, grad_cams=False):
def run(model, data_loader, name, accelerator, output_folder, grad_cams=False):
"""Runs inference on input data, outputs HDF5 files with predictions.
Parameters
@@ -26,8 +27,8 @@ def run(model, data_loader, name, device, output_folder, grad_cams=False):
the local name of this dataset (e.g. ``train``, or ``test``), to be
used when saving measures files.
device : str
device to use ``cpu`` or ``cuda:0``
accelerator : str
accelerator to use
output_folder : str
folder where to store output prediction and model
@@ -48,14 +49,21 @@ def run(model, data_loader, name, device, output_folder, grad_cams=False):
logger.info(f"Output folder: {output_folder}")
os.makedirs(output_folder, exist_ok=True)
logger.info(f"Device: {device}")
accelerator_processor = AcceleratorProcessor(accelerator)
if accelerator_processor.device is None:
devices = "auto"
else:
devices = accelerator_processor.device
logger.info(f"Device: {devices}")
logfile_name = os.path.join(output_folder, "predictions.csv")
logfile_fields = ("filename", "likelihood", "ground_truth")
trainer = Trainer(
accelerator="auto",
devices="auto",
accelerator=accelerator_processor.accelerator,
devices=devices,
callbacks=[
PredictionsWriter(
logfile_name=logfile_name,
Loading