From 653ec39ee00e33ef9dfa63bdda8e713e778f7b00 Mon Sep 17 00:00:00 2001 From: Andre Anjos <andre.dos.anjos@gmail.com> Date: Tue, 25 Jun 2024 15:51:14 +0200 Subject: [PATCH] [libs.segmentation.scripts.predict] Implement memory-efficient prediction writer; Include sample information on saved HDF5 information --- .../libs/classification/engine/predictor.py | 2 +- .../libs/classification/scripts/predict.py | 17 +- src/mednet/libs/common/scripts/predict.py | 32 +-- .../libs/segmentation/engine/evaluator.py | 6 +- .../libs/segmentation/engine/predictor.py | 200 ++++++++++++++++++ .../libs/segmentation/scripts/evaluate.py | 4 +- .../libs/segmentation/scripts/predict.py | 114 ++-------- 7 files changed, 253 insertions(+), 122 deletions(-) create mode 100644 src/mednet/libs/segmentation/engine/predictor.py diff --git a/src/mednet/libs/classification/engine/predictor.py b/src/mednet/libs/classification/engine/predictor.py index 1701a23e..771d918f 100644 --- a/src/mednet/libs/classification/engine/predictor.py +++ b/src/mednet/libs/classification/engine/predictor.py @@ -31,7 +31,7 @@ def run( | MultiClassPredictionSplit | None ): - """Run inference on input data, outputs csv files with predictions. + """Run inference on input data, output predictions. Parameters ---------- diff --git a/src/mednet/libs/classification/scripts/predict.py b/src/mednet/libs/classification/scripts/predict.py index a6f45a11..0a0f4451 100644 --- a/src/mednet/libs/classification/scripts/predict.py +++ b/src/mednet/libs/classification/scripts/predict.py @@ -46,7 +46,6 @@ def predict( """Run inference (generates scores) on all input images, using a pre-trained model.""" import json - import shutil from mednet.libs.classification.engine.predictor import run from mednet.libs.common.engine.device import DeviceManager @@ -55,24 +54,20 @@ def predict( save_json_data, setup_datamodule, ) + from mednet.libs.common.scripts.utils import save_json_with_backup - predictions_file = output_folder / "predictions.json" - predictions_file.parent.mkdir(parents=True, exist_ok=True) + predictions_meta_file = output_folder / "predictions.meta.json" + predictions_meta_file.parent.mkdir(parents=True, exist_ok=True) setup_datamodule(datamodule, model, batch_size, parallel) model = load_checkpoint(model, weight) device_manager = DeviceManager(device) - save_json_data(datamodule, model, predictions_file, device_manager) + save_json_data(datamodule, model, device_manager, predictions_meta_file) predictions = run(model, datamodule, device_manager) - if predictions_file.exists(): - backup = predictions_file.parent / (predictions_file.name + "~") - logger.warning( - f"Output predictions file `{str(predictions_file)}` exists - " - f"backing it up to `{str(backup)}`...", - ) - shutil.copy(predictions_file, backup) + predictions_file = output_folder / "predictions.json" + save_json_with_backup(predictions_file, predictions) with predictions_file.open("w") as f: json.dump(predictions, f, indent=2) diff --git a/src/mednet/libs/common/scripts/predict.py b/src/mednet/libs/common/scripts/predict.py index d38c3841..422b9e56 100644 --- a/src/mednet/libs/common/scripts/predict.py +++ b/src/mednet/libs/common/scripts/predict.py @@ -7,15 +7,15 @@ import pathlib import typing import click -from clapper.click import ResourceOption +import mednet.libs.common.data.datamodule +import mednet.libs.common.models.model from clapper.logging import setup -from mednet.libs.common.models.model import Model logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") -def reusable_options(f): - """Wrap reusable predict script options (for ``experiment``). +def reusable_options(f: typing.Callable): + """Wrap reusable predict script options for other scripts. This decorator equips the target function ``f`` with all (reusable) ``predict`` script options. @@ -30,6 +30,7 @@ def reusable_options(f): ------- The decorated version of function ``f`` """ + from clapper.click import ResourceOption @click.option( "--output-folder", @@ -121,21 +122,24 @@ def reusable_options(f): def setup_datamodule( - datamodule, - model, - batch_size, - parallel, + datamodule: mednet.libs.common.data.datamodule.ConcatDataModule, + model: mednet.libs.common.models.model.Model, + batch_size: int, + parallel: int, ) -> None: # numpydoc ignore=PR01 """Configure and set up the datamodule.""" + datamodule.batch_size = batch_size datamodule.parallel = parallel - datamodule.model_transforms = model.model_transforms + datamodule.model_transforms = list(model.model_transforms) datamodule.prepare_data() datamodule.setup(stage="predict") -def load_checkpoint(model: Model, weight: pathlib.Path) -> Model: +def load_checkpoint( + model: mednet.libs.common.models.model.Model, weight: pathlib.Path +) -> mednet.libs.common.models.model.Model: """Load a model checkpoint for prediction. Parameters @@ -162,10 +166,10 @@ def load_checkpoint(model: Model, weight: pathlib.Path) -> Model: def save_json_data( - datamodule, - model, - output, + datamodule: mednet.libs.common.data.datamodule.ConcatDataModule, + model: mednet.libs.common.models.model.Model, device_manager, + output_file: pathlib.Path, ) -> None: # numpydoc ignore=PR01 """Save prediction hyperparameters into a .json file.""" @@ -187,4 +191,4 @@ def save_json_data( ) json_data.update(model_summary(model)) json_data = {k.replace("_", "-"): v for k, v in json_data.items()} - save_json_with_backup(output.with_suffix(".meta.json"), json_data) + save_json_with_backup(output_file, json_data) diff --git a/src/mednet/libs/segmentation/engine/evaluator.py b/src/mednet/libs/segmentation/engine/evaluator.py index 82ec4bbb..46b2d9ff 100644 --- a/src/mednet/libs/segmentation/engine/evaluator.py +++ b/src/mednet/libs/segmentation/engine/evaluator.py @@ -370,7 +370,7 @@ def load_count( data = numpy.zeros((len(thresholds), 4), dtype=numpy.uint64) for sample in tqdm(predictions, desc="sample"): with h5py.File(prediction_path / sample[1], "r") as f: - pred = numpy.array(f.get("img")) # float32 + pred = numpy.array(f.get("prediction")) # float32 gt = numpy.array(f.get("target")) # boolean mask = numpy.array(f.get("mask")) # boolean data += numpy.array( @@ -411,7 +411,7 @@ def load_predictions( # peak prediction size and number of samples with h5py.File(prediction_path / predictions[0][1], "r") as f: - elements = numpy.array(f.get("img").shape).prod() + elements = numpy.array(f.get("prediction").shape).prod() size = len(predictions) * elements logger.info( f"Data loading will require ({elements} x {len(predictions)} x 5 =) " @@ -424,7 +424,7 @@ def load_predictions( for i, sample in enumerate(tqdm(predictions, desc="sample")): with h5py.File(prediction_path / sample[1], "r") as f: mask = numpy.array(f.get("mask")) # boolean - pred = numpy.array(f.get("img")) # float32 + pred = numpy.array(f.get("prediction")) # float32 pred *= mask.astype(numpy.float32) gt = numpy.array(f.get("target")) # boolean gt &= mask diff --git a/src/mednet/libs/segmentation/engine/predictor.py b/src/mednet/libs/segmentation/engine/predictor.py new file mode 100644 index 00000000..8371c9a4 --- /dev/null +++ b/src/mednet/libs/segmentation/engine/predictor.py @@ -0,0 +1,200 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import logging +import pathlib +import typing + +import h5py +import lightning.pytorch +import lightning.pytorch.callbacks +import torch.utils.data +import tqdm +from mednet.libs.common.engine.device import DeviceManager + +logger = logging.getLogger("mednet") + + +class _HDF5Writer(lightning.pytorch.callbacks.BasePredictionWriter): + """Write HDF5 files for each sample processed by our model. + + Objects of this class can also keep track of samples written to disk and + return a summary list. + + Parameters + ---------- + output_folder + Base directory where to write predictions to. + + write_interval + When will this callback be active. + """ + + def __init__( + self, + output_folder: pathlib.Path, + write_interval: typing.Literal["batch", "epoch", "batch_and_epoch"] = "batch", + ): + super().__init__(write_interval=write_interval) + self.output_folder = output_folder + self._written: list[list[str]] = [] + + def write_on_batch_end( + self, + trainer: lightning.pytorch.Trainer, + pl_module: lightning.pytorch.LightningModule, + prediction: typing.Any, + batch_indices: typing.Sequence[int] | None, + batch: typing.Any, + batch_idx: int, + dataloader_idx: int, + ) -> None: + """Write batch predictions to disk. + + Parameters + ---------- + trainer + The trainer being used. + pl_module + The pytorch module. + prediction + The actual predictions to record. + batch_indices + The relative position of samples on the epoch. + batch + The current batch. + batch_idx + Index of the batch overall. + dataloader_idx + Index of the dataloader overall. + """ + for k, p in enumerate(prediction): + stem = pathlib.Path(p[0]).with_suffix(".hdf5") + output_path = self.output_folder / stem + tqdm.tqdm.write(f"`{p[0]}` -> `{str(output_path)}`") + output_path.parent.mkdir(parents=True, exist_ok=True) + with h5py.File(output_path, "w") as f: + f.create_dataset( + "image", + data=batch[0][k].numpy(), + compression="gzip", + compression_opts=9, + ) + f.create_dataset( + "prediction", + data=p[3].numpy().squeeze(0), + compression="gzip", + compression_opts=9, + ) + f.create_dataset( + "target", + data=(batch[1]["target"][k].squeeze(0).numpy() > 0.5), + compression="gzip", + compression_opts=9, + ) + f.create_dataset( + "mask", + data=(batch[1]["mask"][k].squeeze(0).numpy() > 0.5), + compression="gzip", + compression_opts=9, + ) + self._written.append([p[0], str(stem)]) + + def written(self) -> list[list[str]]: + """Summary of written objects. + + Also resets the internal state. + + Returns + ------- + A list containing a summary of all samples written. + """ + retval = self._written + self._written = [] + return retval + + +def run( + model: lightning.pytorch.LightningModule, + datamodule: lightning.pytorch.LightningDataModule, + device_manager: DeviceManager, + output_folder: pathlib.Path, +) -> dict[str, list[list[str]]] | list[list[list[str]]] | list[list[str]] | None: + """Run inference on input data, output predictions. + + Parameters + ---------- + model + Neural network model (e.g. pasa). + datamodule + The lightning DataModule to use for training **and** validation. + device_manager + An internal device representation, to be used for training and + validation. This representation can be converted into a pytorch device + or a lightning accelerator setup. + output_folder + Folder where to store HDF5 representations of probability maps. + + Returns + ------- + A JSON-able representation of sample data stored at ``output_folder``. + For every split (dataloader), a list of samples in the form + ``[sample-name, hdf5-path]`` is returned. In the cases where the + ``predict_dataloader()`` returns a single loader, we then return a + list. A dictionary is returned in case ``predict_dataloader()`` also + returns a dictionary. + + Raises + ------ + TypeError + If the DataModule's ``predict_dataloader()`` method does not return any + of the types described above. + """ + + from lightning.pytorch.loggers.logger import DummyLogger + + writer = _HDF5Writer(output_folder) + + accelerator, devices = device_manager.lightning_accelerator() + trainer = lightning.pytorch.Trainer( + accelerator=accelerator, + devices=devices, + logger=DummyLogger(), + callbacks=[writer], + ) + + dataloaders = datamodule.predict_dataloader() + + if isinstance(dataloaders, torch.utils.data.DataLoader): + logger.info("Running prediction on a single dataloader...") + trainer.predict(model, dataloaders, return_predictions=False) + return writer.written() + + if isinstance(dataloaders, list): + retval_list = [] + for k, dataloader in enumerate(dataloaders): + logger.info(f"Running prediction on split `{k}`...") + trainer.predict(model, dataloader, return_predictions=False) + retval_list.append(writer.written()) + return retval_list + + if isinstance(dataloaders, dict): + retval_dict = {} + for name, dataloader in dataloaders.items(): + logger.info(f"Running prediction on `{name}` split...") + trainer.predict(model, dataloader, return_predictions=False) + retval_dict[name] = writer.written() + return retval_dict + + if dataloaders is None: + logger.warning("Datamodule did not return any prediction dataloaders!") + return None + + # if you get to this point, then the user is returning something that is + # not supported - complain! + raise TypeError( + f"Datamodule returned strangely typed prediction " + f"dataloaders: `{type(dataloaders)}` - Please write code " + f"to support this use-case.", + ) diff --git a/src/mednet/libs/segmentation/scripts/evaluate.py b/src/mednet/libs/segmentation/scripts/evaluate.py index dbe12bb3..191cde38 100644 --- a/src/mednet/libs/segmentation/scripts/evaluate.py +++ b/src/mednet/libs/segmentation/scripts/evaluate.py @@ -15,7 +15,7 @@ from mednet.libs.segmentation.engine.evaluator import SUPPORTED_METRIC_TYPE logger = setup("mednet") -def _validate_threshold(threshold: float | str, splits: list[str]): +def validate_threshold(threshold: float | str, splits: list[str]): """Validate the user threshold selection and returns parsed threshold. Parameters @@ -190,7 +190,7 @@ def evaluate( json_data = {k.replace("_", "-"): v for k, v in json_data.items()} save_json_with_backup(evaluation_file.with_suffix(".meta.json"), json_data) - threshold = _validate_threshold(threshold, predict_data) + threshold = validate_threshold(threshold, predict_data) threshold_list = numpy.arange( 0.0, (1.0 + 1 / steps), 1 / steps, dtype=numpy.float64 ) diff --git a/src/mednet/libs/segmentation/scripts/predict.py b/src/mednet/libs/segmentation/scripts/predict.py index 9c5c3d44..84c50d47 100644 --- a/src/mednet/libs/segmentation/scripts/predict.py +++ b/src/mednet/libs/segmentation/scripts/predict.py @@ -2,86 +2,40 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -import json -import pathlib - +import clapper.click +import clapper.logging import click -from clapper.click import ResourceOption, verbosity_option -from clapper.logging import setup -from mednet.libs.common.scripts.click import ConfigCommand -from mednet.libs.common.scripts.predict import reusable_options -from PIL import Image - -logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") - - -def _save_hdf5( - img: Image, - target: Image, - mask: Image, - hdf5_path: pathlib.Path, -) -> None: - """Save prediction image, target and mask in an hdf5 file. - - Parameters - ---------- - img - Monochrome Image with prediction maps. - target - Target corresponding to the prediction. - mask - Mask corresponding to the prediction. - hdf5_path - File in which to save the data. - """ - - import h5py - from tqdm import tqdm - - tqdm.write(f"Saving {hdf5_path}...") - hdf5_path.parent.mkdir(parents=True, exist_ok=True) - with h5py.File(hdf5_path, "w") as f: - f.create_dataset( - "img", - data=img.squeeze(0), - compression="gzip", - compression_opts=9, - ) - f.create_dataset( - "target", - data=target.squeeze(0), - compression="gzip", - compression_opts=9, - ) - f.create_dataset( - "mask", - data=mask.squeeze(0), - compression="gzip", - compression_opts=9, - ) +import mednet.libs.common.scripts.click +import mednet.libs.common.scripts.predict + +logger = clapper.logging.setup( + __name__.split(".")[0], format="%(levelname)s: %(message)s" +) @click.command( entry_point_group="mednet.libs.segmentation.config", - cls=ConfigCommand, + cls=mednet.libs.common.scripts.click.ConfigCommand, epilog="""Examples: 1. Run prediction on an existing DataModule configuration: .. code:: sh - mednet segmentation predict -vv lwnet drive --weight=path/to/model.ckpt --output=path/to/predictions.json + mednet segmentation predict -vv lwnet drive --weight=path/to/model.ckpt --output-folder=path/to/predictions 2. Enable multi-processing data loading with 6 processes: .. code:: sh - mednet segmentation predict -vv lwnet drive --parallel=6 --weight=path/to/model.ckpt --output=path/to/predictions.json + mednet segmentation predict -vv lwnet drive --parallel=6 --weight=path/to/model.ckpt --output-folder=path/to/predictions """, ) -@reusable_options -@verbosity_option(logger=logger, cls=ResourceOption, expose_value=False) +@mednet.libs.common.scripts.predict.reusable_options +@clapper.click.verbosity_option( + logger=logger, cls=clapper.click.ResourceOption, expose_value=False +) def predict( output_folder, model, @@ -94,44 +48,22 @@ def predict( ) -> None: # numpydoc ignore=PR01 """Run inference (generates scores) on all input images, using a pre-trained model.""" - import typing - - from mednet.libs.classification.engine.predictor import run from mednet.libs.common.engine.device import DeviceManager from mednet.libs.common.scripts.predict import ( load_checkpoint, save_json_data, setup_datamodule, ) + from mednet.libs.common.scripts.utils import save_json_with_backup + from mednet.libs.segmentation.engine.predictor import run - predictions_file = output_folder / "predictions.json" + predictions_meta_file = output_folder / "predictions.meta.json" setup_datamodule(datamodule, model, batch_size, parallel) model = load_checkpoint(model, weight) device_manager = DeviceManager(device) - save_json_data(datamodule, model, predictions_file, device_manager) - - predictions = run(model, datamodule, device_manager) - - # Save image data (sample, target, mask) into an hdf5 file - json_predictions = {} - assert isinstance( - predictions, typing.Mapping - ), "predictions must be a dictionary or this program will not work!" - for split_name, split in predictions.items(): - pred_paths = [] - for sample in split: - hdf5_path = pathlib.Path(f"{sample[0]}").with_suffix(".hdf5") - _save_hdf5( - sample[3].numpy(), # float32 - sample[1].numpy() > 0.5, # boolean - sample[2].numpy() > 0.5, # boolean - output_folder / hdf5_path, - ) - pred_paths.append([str(sample[0]), str(hdf5_path)]) - json_predictions[split_name] = pred_paths - - # Save path to hdf5 files into predictions.json - with predictions_file.open("w") as f: - json.dump(json_predictions, f, indent=2) - logger.info(f"Predictions saved to `{str(predictions_file)}`") + save_json_data(datamodule, model, device_manager, predictions_meta_file) + + json_predictions = run(model, datamodule, device_manager, output_folder) + + save_json_with_backup(output_folder / "predictions.json", json_predictions) -- GitLab