Commit e68471ad authored by André Anjos


[engine.predictor] Allow prediction to work when batch-size!=0; Discard CSVs and use JSON for output; Implement lightning logging like in training; Separates batched (collated) data at prediction; Adapts all models to new paradigm; Remove CSVPredictionsWriter; Adapt predict script to all changes
parent b0655852
Merge request !6: Making use of LightningDataModule and simplification of data loading
@@ -2,10 +2,7 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later

-import csv
 import logging
-import os
-import pathlib
 import time
 import typing
@@ -380,49 +377,3 @@ class LoggingCallback(lightning.pytorch.Callback):
                 {k: self._to_log[k], "step": float(trainer.current_epoch)}
             )
         self._to_log = {}
-
-
-class PredictionsWriter(lightning.pytorch.callbacks.BasePredictionWriter):
-    """Lightning callback to write predictions to a file."""
-
-    def __init__(
-        self,
-        output_dir: str | pathlib.Path,
-        logfile_fields: typing.Sequence[str],
-        write_interval: typing.Literal["batch", "epoch", "batch_and_epoch"],
-    ):
-        super().__init__(write_interval)
-        self.output_dir = output_dir
-        self.logfile_fields = logfile_fields
-
-    def write_on_epoch_end(
-        self,
-        trainer: lightning.pytorch.Trainer,
-        pl_module: lightning.pytorch.LightningModule,
-        predictions: typing.Sequence[typing.Any],
-        batch_indices: typing.Sequence[typing.Any] | None,
-    ) -> None:
-        dataloader_names = list(trainer.datamodule.predict_dataloader().keys())
-        for dataloader_idx, dataloader_name in enumerate(dataloader_names):
-            logfile = os.path.join(
-                self.output_dir,
-                f"{dataloader_name}.csv",
-            )
-            os.makedirs(os.path.dirname(logfile), exist_ok=True)
-
-            logger.info(f"Saving predictions in {logfile}.")
-
-            with open(logfile, "w") as l_f:
-                logwriter = csv.DictWriter(l_f, fieldnames=self.logfile_fields)
-                logwriter.writeheader()
-                for prediction in predictions[dataloader_idx]:
-                    logwriter.writerow(
-                        {
-                            "filename": prediction[0],
-                            "likelihood": prediction[1].numpy(),
-                            "ground_truth": prediction[2].numpy(),
-                        }
-                    )
-                l_f.flush()
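For context, the per-split CSV files this removed callback produced can be reproduced with a few lines of standard-library code. The field names come from the old `logfile_fields` tuple in the predictor below; the row values are invented for illustration:

import csv
import io

# Reproduces the layout of the old per-split CSV output (e.g. `test.csv`).
fields = ("filename", "likelihood", "ground_truth")
rows = [
    {"filename": "images/patient-001.png", "likelihood": 0.0123, "ground_truth": 0},
    {"filename": "images/patient-002.png", "likelihood": 0.9871, "ground_truth": 1},
]

buffer = io.StringIO()
writer = csv.DictWriter(buffer, fieldnames=fields)
writer.writeheader()
writer.writerows(rows)
print(buffer.getvalue())
# filename,likelihood,ground_truth
# images/patient-001.png,0.0123,0
# images/patient-002.png,0.9871,1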
@@ -3,13 +3,10 @@
 # SPDX-License-Identifier: GPL-3.0-or-later

 import logging
-import os
+import pathlib

 import lightning.pytorch
-from lightning.pytorch import Trainer
-
-from .callbacks import PredictionsWriter
 from .device import DeviceManager

 logger = logging.getLogger(__name__)
@@ -19,55 +16,64 @@ def run(
     model: lightning.pytorch.LightningModule,
     datamodule: lightning.pytorch.LightningDataModule,
     device_manager: DeviceManager,
-    output_folder: str,
-):
-    """Runs inference on input data, outputs csv files with predictions.
+    output_folder: pathlib.Path,
+) -> dict[str, list] | list | list[list] | None:
+    """Runs inference on input data and returns predictions.

     Parameters
     ----------

-    model : :py:class:`torch.nn.Module`
+    model
         Neural network model (e.g. pasa).

     datamodule
         The lightning datamodule to use for training **and** validation.

     device_manager
         An internal device representation, to be used for training and
         validation. This representation can be converted into a pytorch device
         or a torch lightning accelerator setup.

-    output_folder : str
-        Directory in which the results will be saved.
-
-    grad_cams : bool
-        If we export grad cams for every prediction (must be used along
-        a batch size of 1 with the DensenetRS model).
+    output_folder
+        Directory in which the logs will be saved.

     Returns
     -------

-    all_predictions : list
-        All the predictions associated with filename and ground truth.
+    predictions
+        A dictionary containing the predictions for each of the input samples
+        per dataloader. Keys correspond to the original split names defined at
+        the loader. If the datamodule's ``predict_dataloader()`` method does
+        not return a dictionary, then its output is directly passed to the
+        trainer ``predict()`` method.
     """

-    logger.info(f"Output folder: {output_folder}")
-    os.makedirs(output_folder, exist_ok=True)
-
-    logfile_fields = ("filename", "likelihood", "ground_truth")
+    from .loggers.custom_tensorboard_logger import CustomTensorboardLogger
+
+    log_dir = "logs"
+    tensorboard_logger = CustomTensorboardLogger(
+        output_folder,
+        log_dir,
+    )
+    logger.info(
+        f"Monitor prediction with `tensorboard serve "
+        f"--logdir={output_folder}/{log_dir}/`. "
+        f"Then, open a browser on the printed address."
+    )

     accelerator, devices = device_manager.lightning_accelerator()
-    trainer = Trainer(
+    trainer = lightning.pytorch.Trainer(
         accelerator=accelerator,
         devices=devices,
-        callbacks=[
-            PredictionsWriter(
-                output_dir=output_folder,
-                logfile_fields=logfile_fields,
-                write_interval="epoch",
-            ),
-        ],
+        logger=tensorboard_logger,
     )

-    all_predictions = trainer.predict(model, datamodule)
-
-    return all_predictions
+    dataloaders = datamodule.predict_dataloader()
+
+    if isinstance(dataloaders, dict):
+        retval = {}
+        for name, dataloader in dataloaders.items():
+            logger.info(f"Running prediction on `{name}` split...")
+            predictions = trainer.predict(model, dataloader)
+            retval[name] = [
+                sample for batch in predictions for sample in batch  # type: ignore
+            ]
+        return retval
+
+    # just pass all the loaders to the trainer, and let it handle the rest
+    return trainer.predict(model, datamodule)
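The de-collation performed by the new `run()` is a plain nested comprehension over what `trainer.predict()` returns (one list entry per batch). A self-contained sketch, with invented batch contents standing in for the (probability, metadata) samples produced by the models' `predict_step()`:

# Each entry of `predictions` mimics one batch returned by trainer.predict().
predictions = [
    [("sample-1", 0.1), ("sample-2", 0.8)],  # first batch
    [("sample-3", 0.4)],                     # last batch may be smaller
]

flat = [sample for batch in predictions for sample in batch]
assert flat == [("sample-1", 0.1), ("sample-2", 0.8), ("sample-3", 0.4)]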
@@ -164,7 +164,7 @@ def run(
         log_dir,
     )
     logger.info(
-        f"Monitor experiment with `tensorboard serve "
+        f"Monitor training with `tensorboard serve "
         f"--logdir={output_folder}/{log_dir}/`. "
         f"Then, open a browser on the printed address."
     )
...
@@ -14,6 +14,7 @@ import torchvision.models as models
 import torchvision.transforms

 from ..data.typing import TransformSequence
+from .separate import separate
 from .transforms import RGB
 from .typing import Checkpoint
@@ -205,18 +206,9 @@ class Alexnet(pl.LightningModule):
         return self._validation_loss(outputs, labels.float())

     def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False):
-        images = batch[0]
-        labels = batch[1]["label"]
-        names = batch[1]["name"]
-
-        outputs = self(images)
+        outputs = self(batch[0])
         probabilities = torch.sigmoid(outputs)
-
-        return (
-            names[0],
-            torch.flatten(probabilities),
-            torch.flatten(labels),
-        )
+        return separate((probabilities, batch[1]))

     def configure_optimizers(self):
         return self._optimizer_type(
...
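The same three-line `predict_step()` is applied to every model in this commit (Alexnet above; Densenet, LogisticRegression, MultiLayerPerceptron, and Pasa below). A sketch of the batch layout it assumes, with invented shapes and names (`separate()` is the new helper module introduced near the end of this commit):

import torch

# batch[0]: collated input tensors; batch[1]: collated metadata, as produced
# by torch.utils.data.default_collate over (tensor, dict) samples.
batch = (
    torch.zeros(4, 1, 512, 512),  # four single-channel images (invented shape)
    {
        "name": ["s1.png", "s2.png", "s3.png", "s4.png"],
        "label": torch.tensor([0, 1, 0, 1]),
    },
)

probabilities = torch.rand(4, 1)  # stand-in for torch.sigmoid(self(batch[0]))

# separate((probabilities, batch[1])) then returns one (probability, metadata)
# pair per sample, e.g.:
# [(tensor(0.37), {"name": "s1.png", "label": tensor(0)}), ...]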
@@ -14,6 +14,7 @@ import torchvision.models as models
 import torchvision.transforms

 from ..data.typing import TransformSequence
+from .separate import separate
 from .transforms import RGB
 from .typing import Checkpoint
@@ -199,18 +200,9 @@ class Densenet(pl.LightningModule):
         return self._validation_loss(outputs, labels.float())

     def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False):
-        images = batch[0]
-        labels = batch[1]["label"]
-        names = batch[1]["name"]
-
-        outputs = self(images)
+        outputs = self(batch[0])
         probabilities = torch.sigmoid(outputs)
-
-        return (
-            names[0],
-            torch.flatten(probabilities),
-            torch.flatten(labels),
-        )
+        return separate((probabilities, batch[1]))

     def configure_optimizers(self):
         return self._optimizer_type(
...
@@ -8,6 +8,8 @@ import lightning.pytorch as pl
 import torch
 import torch.nn as nn

+from .separate import separate
+

 class LogisticRegression(pl.LightningModule):
     """Logistic regression classifier with a single output.
@@ -106,18 +108,9 @@ class LogisticRegression(pl.LightningModule):
         return {f"extra_validation_loss_{dataloader_idx}": validation_loss}

     def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False):
-        names = batch[0]
-        input = batch[1]
-
-        output = self(input)
-        probabilities = torch.sigmoid(output)
-
-        # necessary check for HED architecture that uses several outputs
-        # for loss calculation instead of just the last concatfuse block
-        if isinstance(output, list):
-            output = output[-1]
-
-        return names[0], torch.flatten(probabilities), torch.flatten(batch[2])
+        outputs = self(batch[0])
+        probabilities = torch.sigmoid(outputs)
+        return separate((probabilities, batch[1]))

     def configure_optimizers(self):
         return self._optimizer_type(
...
@@ -7,6 +7,8 @@ import typing
 import lightning.pytorch as pl
 import torch

+from .separate import separate
+

 class MultiLayerPerceptron(pl.LightningModule):
     """MLP with a variable number of inputs and hidden neurons (single layer).
@@ -111,18 +113,9 @@ class MultiLayerPerceptron(pl.LightningModule):
         return {f"extra_validation_loss_{dataloader_idx}": validation_loss}

     def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False):
-        names = batch[0]
-        input = batch[1]
-
-        output = self(input)
-        probabilities = torch.sigmoid(output)
-
-        # necessary check for HED architecture that uses several outputs
-        # for loss calculation instead of just the last concatfuse block
-        if isinstance(output, list):
-            output = output[-1]
-
-        return names[0], torch.flatten(probabilities), torch.flatten(batch[2])
+        outputs = self(batch[0])
+        probabilities = torch.sigmoid(outputs)
+        return separate((probabilities, batch[1]))

     def configure_optimizers(self):
         return self._optimizer_type(
...
@@ -14,6 +14,7 @@ import torch.utils.data
 import torchvision.transforms

 from ..data.typing import TransformSequence
+from .separate import separate
 from .transforms import Grayscale
 from .typing import Checkpoint
@@ -265,18 +266,9 @@ class Pasa(pl.LightningModule):
         return self._validation_loss(outputs, labels.float())

     def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False):
-        images = batch[0]
-        labels = batch[1]["label"]
-        names = batch[1]["name"]
-
-        outputs = self(images)
+        outputs = self(batch[0])
         probabilities = torch.sigmoid(outputs)
-
-        return (
-            names[0],
-            torch.flatten(probabilities),
-            torch.flatten(labels),
-        )
+        return separate((probabilities, batch[1]))

     def configure_optimizers(self):
         return self._optimizer_type(
...
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

"""Contains the inverse of :py:func:`torch.utils.data.default_collate`."""

import torch

from ..data.typing import Sample


def separate(batch: Sample) -> list[Sample]:
    """Separates a collated batch, reconstituting its samples.

    This function implements the inverse of
    :py:func:`torch.utils.data.default_collate`, and can separate batches of
    data with different attributes back into individual samples. It follows
    the inverse path of that function, and implements the following
    separation algorithms:

    * :class:`torch.Tensor` -> :class:`torch.Tensor` (with the outer
      dimension removed, via :py:func:`torch.flatten`)
    * ``typing.Mapping[K, V[]]`` -> ``[dict[K, V_1], dict[K, V_2], ...]``
    """

    # as of now, this is really simple - to be made more complex upon need
    metadata = [
        {key: value[i] for key, value in batch[1].items()}
        for i in range(len(batch[0]))
    ]
    return list(zip(torch.flatten(batch[0]), metadata))
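A round-trip check of the algorithm above, applying :py:func:`torch.utils.data.default_collate` to two invented samples and then separating them again:

import torch
from torch.utils.data import default_collate

samples = [
    (torch.tensor([0.25]), {"name": "a.png", "label": torch.tensor(0)}),
    (torch.tensor([0.75]), {"name": "b.png", "label": torch.tensor(1)}),
]

batch = default_collate(samples)
# batch[0] -> tensor([[0.2500], [0.7500]])  (stacked data)
# batch[1] -> {"name": ["a.png", "b.png"], "label": tensor([0, 1])}

# The same logic as separate() recovers one (data, metadata) pair per sample:
metadata = [
    {key: value[i] for key, value in batch[1].items()}
    for i in range(len(batch[0]))
]
separated = list(zip(torch.flatten(batch[0]), metadata))
assert separated[0][1]["name"] == "a.png"
assert float(separated[1][0]) == 0.75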
@@ -2,6 +2,8 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later

+import pathlib
+
 import click

 from clapper.click import ConfigCommand, ResourceOption, verbosity_option
@@ -15,49 +17,61 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
     cls=ConfigCommand,
     epilog="""Examples:

-\b
-1. Runs prediction on an existing dataset configuration:
+1. Runs prediction on an existing datamodule configuration:

 .. code:: sh

-   ptbench predict -vv pasa montgomery --weight=path/to/model_final.pth --output-folder=path/to/predictions
+\b
+   ptbench predict -vv pasa montgomery --weight=path/to/model.ckpt --output=path/to/predictions.json
+
+2. Enables multi-processing data loading with 6 processes:
+
+.. code:: sh
+
+\b
+   ptbench predict -vv pasa montgomery --parallel=6 --weight=path/to/model.ckpt --output=path/to/predictions.json
 """,
 )
 @click.option(
-    "--output-folder",
+    "--output",
     "-o",
-    help="Path where to store the predictions (created if does not exist)",
+    help="""Path where to store the JSON predictions for all samples in the
+    input datamodule (leading directories are created if they do not
+    exist).""",
     required=True,
     default="results",
     cls=ResourceOption,
-    type=click.Path(),
+    type=click.Path(
+        file_okay=True, dir_okay=False, writable=True, path_type=pathlib.Path
+    ),
 )
 @click.option(
     "--model",
     "-m",
-    help="A torch.nn.Module instance implementing the network to be evaluated",
+    help="""A lightning module instance implementing the network architecture
+    (not the weights, necessarily) to be used for prediction.""",
     required=True,
     cls=ResourceOption,
 )
 @click.option(
     "--datamodule",
     "-d",
-    help="A torch.utils.data.dataset.Dataset instance implementing a dataset "
-    "to be used for running prediction, possibly including all pre-processing "
-    "pipelines required or, optionally, a dictionary mapping string keys to "
-    "torch.utils.data.dataset.Dataset instances. All keys that do not start "
-    "with an underscore (_) will be processed.",
+    help="""A lightning data module that will be asked for prediction data
+    loaders. Typically, this includes all configured splits in a datamodule;
+    however, this is not a requirement. A datamodule that returns a single
+    dataloader for prediction (wrapped in a dictionary) is acceptable.""",
     required=True,
     cls=ResourceOption,
 )
 @click.option(
     "--batch-size",
     "-b",
-    help="Number of samples in every batch (this parameter affects memory requirements for the network)",
+    help="""Number of samples in every batch (this parameter affects memory
+    requirements for the network).""",
     required=True,
     show_default=True,
-    default=1,
+    default=10,
     type=click.IntRange(min=1),
     cls=ResourceOption,
 )
@@ -73,54 +87,76 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 @click.option(
     "--weight",
     "-w",
-    help="Path or URL to pretrained model file (.ckpt extension)",
+    help="""Path or URL to pretrained model file (`.ckpt` extension),
+    corresponding to the architecture set with `--model`.""",
     required=True,
     cls=ResourceOption,
+    type=click.Path(exists=True, file_okay=True, dir_okay=False, readable=True),
+)
+@click.option(
+    "--parallel",
+    "-P",
+    help="""Use multiprocessing for data loading: if set to -1 (default),
+    disables multiprocessing data loading. Set to 0 to enable as many data
+    loading instances as there are processing cores available on the system.
+    Set to >= 1 to enable that many multiprocessing instances for data
+    loading.""",
+    type=click.IntRange(min=-1),
+    show_default=True,
+    required=True,
+    default=-1,
+    cls=ResourceOption,
 )
 @verbosity_option(logger=logger, cls=ResourceOption, expose_value=False)
 def predict(
-    output_folder,
+    output,
     model,
     datamodule,
     batch_size,
     device,
     weight,
+    parallel,
     **_,
 ) -> None:
     """Predicts Tuberculosis presence (probabilities) on input images."""

-    import os
+    import json
+    import shutil

-    import numpy as np
-
-    from matplotlib.backends.backend_pdf import PdfPages
-
     from ..engine.device import DeviceManager
     from ..engine.predictor import run
-    from ..utils.plot import relevance_analysis_plot

     datamodule.set_chunk_size(batch_size, 1)
+    datamodule.parallel = parallel
     datamodule.model_transforms = model.model_transforms

     datamodule.prepare_data()
     datamodule.setup(stage="predict")

-    logger.info(f"Loading checkpoint from {weight}")
+    logger.info(f"Loading checkpoint from `{weight}`...")
     model = model.load_from_checkpoint(weight, strict=False)

-    # Logistic regressor weights
-    if model.name == "logistic_regression":
-        logger.info("Logistic regression identified: saving model weights")
-        for param in model.parameters():
-            model_weights = np.array(param.data.reshape(-1))
-            break
-        filepath = os.path.join(output_folder, "LogReg_Weights.pdf")
-        logger.info(f"Creating and saving weights plot at {filepath}...")
-        os.makedirs(os.path.dirname(filepath), exist_ok=True)
-        pdf = PdfPages(filepath)
-        pdf.savefig(
-            relevance_analysis_plot(model_weights, title="LogReg model weights")
-        )
-        pdf.close()
-
-    _ = run(model, datamodule, DeviceManager(device), output_folder)
+    predictions = run(model, datamodule, DeviceManager(device), output.parent)
+
+    output.parent.mkdir(parents=True, exist_ok=True)
+    if output.exists():
+        backup = output.parent / (output.name + "~")
+        logger.warning(
+            f"Output predictions file `{str(output)}` exists - "
+            f"backing it up to `{str(backup)}`..."
+        )
+        shutil.copy(output, backup)
+
+    with output.open("w") as f:
+        flat_predictions: dict[str, list[list]] = {}
+        # creates a flat representation of predictions that is similar to our
+        # own JSON split files
+        for split_name, split_values in predictions.items():  # type: ignore
+            flat_predictions.setdefault(
+                split_name,
+                [
+                    [v[1]["name"], v[1]["label"].item(), v[0].item()]
+                    for v in split_values
+                ],
+            )
+        json.dump(flat_predictions, f, indent=2)
+
+    logger.info(f"Predictions saved to `{str(output)}`")