diff --git a/src/ptbench/engine/callbacks.py b/src/ptbench/engine/callbacks.py index 2da897a7753c474e2be19edff907e13597b012b5..8669718ea8efc5085c090af1db8b82fd89e16e4f 100644 --- a/src/ptbench/engine/callbacks.py +++ b/src/ptbench/engine/callbacks.py @@ -2,10 +2,7 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -import csv import logging -import os -import pathlib import time import typing @@ -380,49 +377,3 @@ class LoggingCallback(lightning.pytorch.Callback): {k: self._to_log[k], "step": float(trainer.current_epoch)} ) self._to_log = {} - - -class PredictionsWriter(lightning.pytorch.callbacks.BasePredictionWriter): - """Lightning callback to write predictions to a file.""" - - def __init__( - self, - output_dir: str | pathlib.Path, - logfile_fields: typing.Sequence[str], - write_interval: typing.Literal["batch", "epoch", "batch_and_epoch"], - ): - super().__init__(write_interval) - self.output_dir = output_dir - self.logfile_fields = logfile_fields - - def write_on_epoch_end( - self, - trainer: lightning.pytorch.Trainer, - pl_module: lightning.pytorch.LightningModule, - predictions: typing.Sequence[typing.Any], - batch_indices: typing.Sequence[typing.Any] | None, - ) -> None: - dataloader_names = list(trainer.datamodule.predict_dataloader().keys()) - - for dataloader_idx, dataloader_name in enumerate(dataloader_names): - logfile = os.path.join( - self.output_dir, - f"{dataloader_name}.csv", - ) - os.makedirs(os.path.dirname(logfile), exist_ok=True) - - logger.info(f"Saving predictions in {logfile}.") - - with open(logfile, "w") as l_f: - logwriter = csv.DictWriter(l_f, fieldnames=self.logfile_fields) - logwriter.writeheader() - - for prediction in predictions[dataloader_idx]: - logwriter.writerow( - { - "filename": prediction[0], - "likelihood": prediction[1].numpy(), - "ground_truth": prediction[2].numpy(), - } - ) - l_f.flush() diff --git a/src/ptbench/engine/predictor.py b/src/ptbench/engine/predictor.py index 
def run(
    model: lightning.pytorch.LightningModule,
    datamodule: lightning.pytorch.LightningDataModule,
    device_manager: DeviceManager,
    output_folder: pathlib.Path,
) -> dict[str, list] | list | list[list] | None:
    """Runs inference on input data and returns the predictions.

    Parameters
    ----------
    model
        Neural network model (e.g. pasa).
    datamodule
        The lightning datamodule whose ``predict_dataloader()`` method
        provides the data to run prediction on.
    device_manager
        An internal device representation, to be used for prediction. This
        representation can be converted into a pytorch device or a torch
        lightning accelerator setup.
    output_folder
        Directory in which the logs will be saved.

    Returns
    -------
    predictions
        A dictionary containing the predictions for each of the input samples
        per dataloader. Keys correspond to the original split names defined at
        the loader, and values are flat lists of samples (batches are
        concatenated). If the datamodule's ``predict_dataloader()`` method
        does not return a dictionary, then the datamodule is directly passed
        to the trainer ``predict()`` method, and its output is returned
        unchanged.
    """

    from .loggers.custom_tensorboard_logger import CustomTensorboardLogger

    log_dir = "logs"
    tensorboard_logger = CustomTensorboardLogger(
        output_folder,
        log_dir,
    )
    logger.info(
        f"Monitor prediction with `tensorboard serve "
        f"--logdir={output_folder}/{log_dir}/`. "
        f"Then, open a browser on the printed address."
    )

    accelerator, devices = device_manager.lightning_accelerator()
    trainer = lightning.pytorch.Trainer(
        accelerator=accelerator,
        devices=devices,
        logger=tensorboard_logger,
    )

    dataloaders = datamodule.predict_dataloader()

    if isinstance(dataloaders, dict):
        retval: dict[str, list] = {}
        for name, dataloader in dataloaders.items():
            logger.info(f"Running prediction on `{name}` split...")
            # ``Trainer.predict()`` is typed as optional: treat a missing
            # result as "no batches" instead of failing while iterating None.
            batches = trainer.predict(model, dataloader) or []
            # each batch is a list of samples: concatenate all batches of
            # this split into a single, flat list of samples
            retval[name] = [sample for batch in batches for sample in batch]
        return retval

    # just pass all the loaders to the trainer, let it handle
    return trainer.predict(model, datamodule)
def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False):
    """Computes per-sample probabilities for one batch.

    The first element of ``batch`` is fed to the model, the sigmoid of the
    model output is taken, and the result is handed to ``separate()``
    together with the second element of ``batch``.
    """
    scores = torch.sigmoid(self(batch[0]))
    return separate((scores, batch[1]))
def separate(batch: "Sample") -> "list[Sample]":
    """Separates a collated batch reconstituting its samples.

    This function implements the inverse of
    :py:func:`torch.utils.data.default_collate`, and can separate, into
    samples, batches of data with different attributes.  It follows the inverse
    path of that function, and implements the following separation algorithms:

    * :class:`torch.Tensor` -> :class:`torch.Tensor` (with a removed outer
      dimension, via :py:func:`torch.flatten`).  Samples holding a single
      value come out as scalar (0-d) tensors; samples holding several values
      come out as 1-d tensors with all values of that sample.
    * ``typing.Mapping[K, V[]]`` -> ``[dict[K, V_1], dict[K, V_2], ...]``

    Parameters
    ----------
    batch
        A 2-tuple with the collated tensor (first dimension indexes the
        samples), followed by a mapping from metadata keys to per-sample
        sequences of values.

    Returns
    -------
    samples
        One ``(tensor, metadata)`` tuple per sample of the input batch.
    """

    data, collated_metadata = batch[0], batch[1]
    n_samples = len(data)

    if n_samples == 0:
        return []

    metadata = [
        {key: value[i] for key, value in collated_metadata.items()}
        for i in range(n_samples)
    ]

    if data.dim() > 1 and data[0].numel() > 1:
        # several values per sample: flatten each sample on its own.
        # Flattening the whole batch would yield more entries than samples,
        # and zip() would silently pair samples with the wrong values.
        values = torch.flatten(data, start_dim=1)
    else:
        # a single value per sample: one scalar tensor per sample
        values = torch.flatten(data)

    return list(zip(values, metadata))
Enables multi-processing data loading with 6 processes: + + .. code:: sh + + \b + ptbench predict -vv pasa montgomery --parallel=6 --weight=path/to/model.ckpt --output=path/to/predictions.json """, ) @click.option( - "--output-folder", + "--output", "-o", - help="Path where to store the predictions (created if does not exist)", + help="""Path where to store the JSON predictions for all samples in the + input datamodule (leading directories are created if they do not not + exist).""", required=True, default="results", cls=ResourceOption, - type=click.Path(), + type=click.Path( + file_okay=True, dir_okay=False, writable=True, path_type=pathlib.Path + ), ) @click.option( "--model", "-m", - help="A torch.nn.Module instance implementing the network to be evaluated", + help="""A lightining module instance implementing the network architecture + (not the weights, necessarily) to be used for prediction.""", required=True, cls=ResourceOption, ) @click.option( "--datamodule", "-d", - help="A torch.utils.data.dataset.Dataset instance implementing a dataset " - "to be used for running prediction, possibly including all pre-processing " - "pipelines required or, optionally, a dictionary mapping string keys to " - "torch.utils.data.dataset.Dataset instances. All keys that do not start " - "with an underscore (_) will be processed.", + help="""A lighting data module that will be asked for prediction data + loaders. Typically, this includes all configured splits in a datamodule, + however this is not a requirement. 
A datamodule that returns a single + dataloader for prediction (wrapped in a dictionary) is acceptable.""", required=True, cls=ResourceOption, ) @click.option( "--batch-size", "-b", - help="Number of samples in every batch (this parameter affects memory requirements for the network)", + help="""Number of samples in every batch (this parameter affects memory + requirements for the network).""", required=True, show_default=True, - default=1, + default=10, type=click.IntRange(min=1), cls=ResourceOption, ) @@ -73,54 +87,76 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") @click.option( "--weight", "-w", - help="Path or URL to pretrained model file (.ckpt extension)", + help="""Path or URL to pretrained model file (`.ckpt` extension), + corresponding to the architecture set with `--model`.""", + required=True, + cls=ResourceOption, + type=click.Path(exists=True, file_okay=True, dir_okay=False, readable=True), +) +@click.option( + "--parallel", + "-P", + help="""Use multiprocessing for data loading: if set to -1 (default), + disables multiprocessing data loading. Set to 0 to enable as many data + loading instances as processing cores as available in the system. 
def predict(
    output,
    model,
    datamodule,
    batch_size,
    device,
    weight,
    parallel,
    **_,
) -> None:
    """Predicts Tuberculosis presence (probabilities) on input images."""
    # NOTE: the docstring above is deliberately left as a one-liner, as click
    # shows a command's docstring as its help text.

    import json
    import shutil

    from ..engine.device import DeviceManager
    from ..engine.predictor import run

    datamodule.set_chunk_size(batch_size, 1)
    datamodule.parallel = parallel
    datamodule.model_transforms = model.model_transforms

    datamodule.prepare_data()
    datamodule.setup(stage="predict")

    logger.info(f"Loading checkpoint from `{weight}`...")
    # ``load_from_checkpoint`` is a classmethod: call it on the class rather
    # than on the instance.  NOTE(review): Lightning 2.x is believed to refuse
    # instance calls to this method - confirm against the pinned version.
    model = type(model).load_from_checkpoint(weight, strict=False)

    predictions = run(model, datamodule, DeviceManager(device), output.parent)

    # validate and convert the predictions *before* touching the output file,
    # so that a failure here cannot leave a truncated file behind
    if not isinstance(predictions, dict):
        raise TypeError(
            f"Expected a dictionary of predictions per split, but got "
            f"`{type(predictions).__name__}` - the datamodule's "
            f"`predict_dataloader()` must return a dictionary of dataloaders."
        )

    # creates a flat representation of predictions that is similar to our
    # own JSON split files: [name, label, probability] for each sample.
    # ``tolist()`` returns a plain number for scalars (like ``item()``), and
    # a list of numbers for samples holding more than one value.
    flat_predictions: dict[str, list[list]] = {
        split_name: [
            [v[1]["name"], v[1]["label"].tolist(), v[0].tolist()]
            for v in split_values
        ]
        for split_name, split_values in predictions.items()
    }

    output.parent.mkdir(parents=True, exist_ok=True)
    if output.exists():
        backup = output.parent / (output.name + "~")
        logger.warning(
            f"Output predictions file `{str(output)}` exists - "
            f"backing it up to `{str(backup)}`..."
        )
        shutil.copy(output, backup)

    with output.open("w") as f:
        json.dump(flat_predictions, f, indent=2)
    logger.info(f"Predictions saved to `{str(output)}`")