Commit f39ae93f authored by Daniel CARRON, committed by André Anjos

Functional predictions

parent cfd3773e
Merge request !6: Making use of LightningDataModule and simplification of data loading
@@ -82,7 +82,8 @@ class RawDataLoader(_BaseRawDataLoader):
tensor = self.transform(
load_pil_baw(os.path.join(self.datadir, sample[0]))
)
return tensor, dict(label=sample[1]) # type: ignore[arg-type]
return tensor, dict(label=sample[1], name=sample[0]) # type: ignore[arg-type]
def label(self, sample: tuple[str, int]) -> int:
"""Loads a single image sample label from the disk.
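With this change, every loaded sample carries its source file name alongside its label, which the prediction code further down uses to write per-file results. A minimal sketch of the resulting sample layout (the method name sample() and the file name are assumptions for illustration only):

loader = RawDataLoader()
# Hypothetical call: a (tensor, metadata) pair, where metadata now includes the name.
tensor, metadata = loader.sample(("images/patient_001.png", 1))
assert metadata == {"label": 1, "name": "images/patient_001.png"}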
@@ -18,7 +18,7 @@ from torchvision import transforms
from ..datamodule import CachingDataModule
from ..split import JSONDatabaseSplit
from .raw_data_loader import raw_data_loader
from .loader import RawDataLoader
datamodule = CachingDataModule(
database_split=JSONDatabaseSplit(
@@ -26,16 +26,10 @@ datamodule = CachingDataModule(
"default.json.bz2"
)
),
raw_data_loader=raw_data_loader,
cache_samples=False,
# train_sampler: typing.Optional[torch.utils.data.Sampler] = None,
raw_data_loader=RawDataLoader(),
model_transforms=[
transforms.ToPILImage(),
transforms.Lambda(lambda x: x.convert("RGB")),
transforms.ToTensor(),
],
# batch_size = 1,
# batch_chunk_count = 1,
# drop_incomplete_batch = False,
# parallel = -1,
)
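For reference, a hedged sketch of how a datamodule configured this way might be driven at prediction time; it assumes predict_dataloader() returns a dictionary of named loaders (as the prediction writer below implies) and uses the set_chunk_size() helper seen in the rewritten predict command:

datamodule.set_chunk_size(1, 1)      # batch size 1, one chunk per batch
datamodule.setup(stage="predict")    # standard Lightning stage name
for name, loader in datamodule.predict_dataloader().items():
    for images, metadata in loader:
        ...                          # metadata carries "label" and "name" entries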
@@ -73,3 +73,6 @@ DataLoader = torch.utils.data.DataLoader[Sample]
We iterate over Sample objects in this case.
"""
Checkpoint = dict[str, typing.Any]
"""Definition of a lightning checkpoint."""
@@ -398,26 +398,27 @@ class PredictionsWriter(lightning.pytorch.callbacks.BasePredictionWriter):
predictions: typing.Sequence[typing.Any],
batch_indices: typing.Sequence[typing.Any] | None,
) -> None:
for dataloader_idx, dataloader_results in enumerate(predictions):
dataloader_name = list(
trainer.datamodule.predict_dataloader().keys()
)[dataloader_idx].replace("_loader", "")
dataloader_name = list(trainer.datamodule.predict_dataloader().keys())[
0
]
logfile = os.path.join(
self.output_dir, dataloader_name, "predictions.csv"
)
os.makedirs(os.path.dirname(logfile), exist_ok=True)
with open(logfile, "w") as l_f:
logwriter = csv.DictWriter(l_f, fieldnames=self.logfile_fields)
logwriter.writeheader()
for prediction in dataloader_results:
logwriter.writerow(
{
"filename": prediction[0],
"likelihood": prediction[1].numpy(),
"ground_truth": prediction[2].numpy(),
}
)
l_f.flush()
logfile = os.path.join(
self.output_dir, f"predictions_{dataloader_name}_set.csv"
)
os.makedirs(os.path.dirname(logfile), exist_ok=True)
logger.info(f"Saving predictions in {logfile}.")
with open(logfile, "w") as l_f:
logwriter = csv.DictWriter(l_f, fieldnames=self.logfile_fields)
logwriter.writeheader()
for prediction in predictions:
logwriter.writerow(
{
"filename": prediction[0],
"likelihood": prediction[1].numpy(),
"ground_truth": prediction[2].numpy(),
}
)
l_f.flush()
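A minimal sketch of how this writer might be attached to a Trainer and what it produces; the constructor argument names are assumed from the attributes used above (output_dir, logfile_fields) and from BasePredictionWriter's write_interval:

writer = PredictionsWriter(
    output_dir="results",
    logfile_fields=("filename", "likelihood", "ground_truth"),
    write_interval="epoch",
)
trainer = Trainer(callbacks=[writer])
trainer.predict(model, datamodule=datamodule)
# writes results/predictions_<dataloader name>_set.csv, one row per sample:
# filename, likelihood, ground_truth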
@@ -5,15 +5,22 @@
import logging
import os
import lightning.pytorch
from lightning.pytorch import Trainer
from ..utils.accelerator import AcceleratorProcessor
from .callbacks import PredictionsWriter
from .device import DeviceManager
logger = logging.getLogger(__name__)
def run(model, datamodule, accelerator, output_folder, grad_cams=False):
def run(
model: lightning.pytorch.LightningModule,
datamodule: lightning.pytorch.LightningDataModule,
device_manager: DeviceManager,
output_folder: str,
):
"""Runs inference on input data, outputs csv files with predictions.
Parameters
@@ -21,11 +28,13 @@ def run(model, datamodule, accelerator, output_folder, grad_cams=False):
model : :py:class:`torch.nn.Module`
Neural network model (e.g. pasa).
data_loader : py:class:`torch.torch.utils.data.DataLoader`
The pytorch Dataloader used to iterate over batches.
datamodule
The lightning datamodule to use for prediction.
accelerator : str
A string indicating the accelerator to use (e.g. "cpu" or "gpu"). The device can also be specified (gpu:0)
device_manager
An internal device representation, to be used for inference. This
representation can be converted into a pytorch device
or a torch lightning accelerator setup.
output_folder : str
Directory in which the results will be saved.
@@ -44,19 +53,11 @@ def run(model, datamodule, accelerator, output_folder, grad_cams=False):
logger.info(f"Output folder: {output_folder}")
os.makedirs(output_folder, exist_ok=True)
accelerator_processor = AcceleratorProcessor(accelerator)
if accelerator_processor.device is None:
devices = "auto"
else:
devices = accelerator_processor.device
logger.info(f"Device: {devices}")
logfile_fields = ("filename", "likelihood", "ground_truth")
accelerator, devices = device_manager.lightning_accelerator()
trainer = Trainer(
accelerator=accelerator_processor.accelerator,
accelerator=accelerator,
devices=devices,
callbacks=[
PredictionsWriter(
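The DeviceManager introduced here replaces the removed AcceleratorProcessor: it turns a device string into the (accelerator, devices) pair Lightning expects. A rough usage sketch; the exact return values shown are an assumption:

device_manager = DeviceManager("cuda:0")   # or "cpu"
accelerator, devices = device_manager.lightning_accelerator()
# e.g. accelerator == "gpu", devices == [0]   (illustrative values)
trainer = Trainer(accelerator=accelerator, devices=devices)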
@@ -8,13 +8,12 @@ import typing
import lightning.pytorch as pl
import torch
import torch.nn
import torch.nn.functional as F
import torch.optim.optimizer
import torch.utils.data
import torchvision.models as models
import torchvision.transforms
from ..data.typing import DataLoader, TransformSequence
from ..data.typing import Checkpoint, DataLoader, TransformSequence
logger = logging.getLogger(__name__)
@@ -61,10 +60,10 @@ class Alexnet(pl.LightningModule):
def __init__(
self,
train_loss: torch.nn.Module,
validation_loss: torch.nn.Module | None,
optimizer_type: type[torch.optim.Optimizer],
optimizer_arguments: dict[str, typing.Any],
train_loss: torch.nn.Module = torch.nn.BCEWithLogitsLoss(),
validation_loss: torch.nn.Module | None = None,
optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
optimizer_arguments: dict[str, typing.Any] = {},
augmentation_transforms: TransformSequence = [],
pretrained: bool = False,
):
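Because every constructor argument now has a default, the model can be instantiated without any configuration, which suits the click resource system used by the CLI. A small sketch (the learning-rate override is illustrative):

model = Alexnet()                                  # BCEWithLogitsLoss + Adam defaults
model = Alexnet(optimizer_arguments={"lr": 1e-4})  # override selected defaults only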
@@ -105,6 +104,32 @@ class Alexnet(pl.LightningModule):
return x
def on_save_checkpoint(self, checkpoint: Checkpoint) -> None:
"""Called by Lightning to restore your model.
If you saved something with on_save_checkpoint() this is your chance to restore this.
Parameters
----------
checkpoint:
Loaded checkpoint
"""
checkpoint["normalizer"] = self.normalizer
def on_load_checkpoint(self, checkpoint: Checkpoint) -> None:
"""Called by Lightning when saving a checkpoint to give you a chance to
store anything else you might want to save.
Parameters
----------
checkpoint:
Loaded checkpoint
"""
logger.info("Restoring normalizer from checkpoint.")
self.normalizer = checkpoint["normalizer"]
def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None:
"""Initializes the normalizer for the current model.
@@ -214,7 +239,7 @@ class Alexnet(pl.LightningModule):
def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False):
images = batch[0]
labels = batch[1]["label"]
names = batch[1]["names"]
names = batch[1]["name"]
outputs = self(images)
probabilities = torch.sigmoid(outputs)
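The remainder of predict_step() is not shown in this hunk; for context, a hedged sketch of the per-sample records the PredictionsWriter above expects to receive, namely (filename, likelihood, ground_truth) triples:

# Illustrative only: pair each file name with its probability and ground truth.
results = [
    (name, probability.cpu(), label.cpu())
    for name, probability, label in zip(names, probabilities, labels)
]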
@@ -8,13 +8,12 @@ import typing
import lightning.pytorch as pl
import torch
import torch.nn
import torch.nn.functional as F
import torch.optim.optimizer
import torch.utils.data
import torchvision.models as models
import torchvision.transforms
from ..data.typing import DataLoader, TransformSequence
from ..data.typing import Checkpoint, DataLoader, TransformSequence
logger = logging.getLogger(__name__)
@@ -59,12 +58,12 @@ class Densenet(pl.LightningModule):
def __init__(
self,
train_loss: torch.nn.Module,
validation_loss: torch.nn.Module | None,
optimizer_type: type[torch.optim.Optimizer],
optimizer_arguments: dict[str, typing.Any],
train_loss: torch.nn.Module = torch.nn.BCEWithLogitsLoss(),
validation_loss: torch.nn.Module | None = None,
optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
optimizer_arguments: dict[str, typing.Any] = {},
augmentation_transforms: TransformSequence = [],
pretrained: bool= False,
pretrained: bool = False,
):
super().__init__()
@@ -98,13 +97,38 @@ class Densenet(pl.LightningModule):
)
def forward(self, x):
x = self.normalizer(x) # type: ignore
x = self.model_ft(x)
return x
def on_save_checkpoint(self, checkpoint: Checkpoint) -> None:
"""Called by Lightning to restore your model.
If you saved something with on_save_checkpoint() this is your chance to restore this.
Parameters
----------
checkpoint:
Loaded checkpoint
"""
checkpoint["normalizer"] = self.normalizer
def on_load_checkpoint(self, checkpoint: Checkpoint) -> None:
"""Called by Lightning when saving a checkpoint to give you a chance to
store anything else you might want to save.
Parameters
----------
checkpoint:
Loaded checkpoint
"""
logger.info("Restoring normalizer from checkpoint.")
self.normalizer = checkpoint["normalizer"]
def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None:
"""Initializes the normalizer for the current model.
@@ -216,7 +240,7 @@ class Densenet(pl.LightningModule):
def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False):
images = batch[0]
labels = batch[1]["label"]
names = batch[1]["names"]
names = batch[1]["name"]
outputs = self(images)
probabilities = torch.sigmoid(outputs)
@@ -13,7 +13,7 @@ import torch.optim.optimizer
import torch.utils.data
import torchvision.transforms
from ..data.typing import DataLoader, TransformSequence
from ..data.typing import Checkpoint, DataLoader, TransformSequence
logger = logging.getLogger(__name__)
@@ -58,10 +58,10 @@ class Pasa(pl.LightningModule):
def __init__(
self,
train_loss: torch.nn.Module,
validation_loss: torch.nn.Module | None,
optimizer_type: type[torch.optim.Optimizer],
optimizer_arguments: dict[str, typing.Any],
train_loss: torch.nn.Module = torch.nn.BCEWithLogitsLoss(),
validation_loss: torch.nn.Module | None = None,
optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
optimizer_arguments: dict[str, typing.Any] = {},
augmentation_transforms: TransformSequence = [],
):
super().__init__()
@@ -185,10 +185,29 @@ class Pasa(pl.LightningModule):
return x
def on_save_checkpoint(self, checkpoint):
def on_save_checkpoint(self, checkpoint: Checkpoint) -> None:
"""Called by Lightning to restore your model.
If you saved something with on_save_checkpoint() this is your chance to restore this.
Parameters
----------
checkpoint:
Loaded checkpoint
"""
checkpoint["normalizer"] = self.normalizer
def on_load_checkpoint(self, checkpoint):
def on_load_checkpoint(self, checkpoint: Checkpoint) -> None:
"""Called by Lightning when saving a checkpoint to give you a chance to
store anything else you might want to save.
Parameters
----------
checkpoint:
Loaded checkpoint
"""
logger.info("Restoring normalizer from checkpoint.")
self.normalizer = checkpoint["normalizer"]
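Together, the two hooks make the normalizer part of the Lightning checkpoint, so it survives a save/load round trip without being recomputed. A minimal sketch (paths are placeholders, and set_normalizer() is assumed to exist on this model as it does on the others):

model = Pasa()
model.set_normalizer(datamodule.train_dataloader())         # compute normalization stats
trainer.fit(model, datamodule)                               # checkpoint now embeds "normalizer"
restored = Pasa.load_from_checkpoint("path/to/last.ckpt")    # on_load_checkpoint() restores it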
@@ -289,7 +308,7 @@ class Pasa(pl.LightningModule):
def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False):
images = batch[0]
labels = batch[1]["label"]
names = batch[1]["names"]
names = batch[1]["name"]
outputs = self(images)
probabilities = torch.sigmoid(outputs)
@@ -62,9 +62,9 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
cls=ResourceOption,
)
@click.option(
"--accelerator",
"-a",
help='A string indicating the accelerator to use (e.g. "cpu" or "gpu"). The device can also be specified (gpu:0)',
"--device",
"-d",
help='A string indicating the device to use (e.g. "cpu" or "cuda:0")',
show_default=True,
required=True,
default="cpu",
@@ -77,22 +77,14 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
required=True,
cls=ResourceOption,
)
@click.option(
"--grad-cams",
"-g",
help="If set, generate grad cams for each prediction (must use batch of 1)",
is_flag=True,
cls=ResourceOption,
)
@verbosity_option(logger=logger, cls=ResourceOption, expose_value=False)
def predict(
output_folder,
model,
datamodule,
batch_size,
accelerator,
device,
weight,
grad_cams,
**_,
) -> None:
"""Predicts Tuberculosis presence (probabilities) on input images."""
@@ -103,12 +95,11 @@ def predict(
from matplotlib.backends.backend_pdf import PdfPages
from ..engine.device import DeviceManager
from ..engine.predictor import run
from ..utils.plot import relevance_analysis_plot
datamodule = datamodule(
batch_size=batch_size,
)
datamodule.set_chunk_size(batch_size, 1)
logger.info(f"Loading checkpoint from {weight}")
model = model.load_from_checkpoint(weight, strict=False)
@@ -128,4 +119,4 @@ def predict(
)
pdf.close()
_ = run(model, datamodule, accelerator, output_folder, grad_cams)
_ = run(model, datamodule, DeviceManager(device), output_folder)
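The run above leaves one CSV per prediction dataloader in the output folder. A short, hedged sketch of reading such a file back with the standard csv module (the path is illustrative):

import csv

with open("results/predictions_test_set.csv") as f:
    for row in csv.DictReader(f):
        print(row["filename"], row["likelihood"], row["ground_truth"])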