Skip to content
Snippets Groups Projects
Commit 653ec39e authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[libs.segmentation.scripts.predict] Implement memory-efficient prediction...

[libs.segmentation.scripts.predict] Implement memory-efficient prediction writer; Include sample information on saved HDF5 information
parent fcf9a7d7
No related branches found
No related tags found
1 merge request!46Create common library
......@@ -31,7 +31,7 @@ def run(
| MultiClassPredictionSplit
| None
):
"""Run inference on input data, outputs csv files with predictions.
"""Run inference on input data, output predictions.
Parameters
----------
......
......@@ -46,7 +46,6 @@ def predict(
"""Run inference (generates scores) on all input images, using a pre-trained model."""
import json
import shutil
from mednet.libs.classification.engine.predictor import run
from mednet.libs.common.engine.device import DeviceManager
......@@ -55,24 +54,20 @@ def predict(
save_json_data,
setup_datamodule,
)
from mednet.libs.common.scripts.utils import save_json_with_backup
predictions_file = output_folder / "predictions.json"
predictions_file.parent.mkdir(parents=True, exist_ok=True)
predictions_meta_file = output_folder / "predictions.meta.json"
predictions_meta_file.parent.mkdir(parents=True, exist_ok=True)
setup_datamodule(datamodule, model, batch_size, parallel)
model = load_checkpoint(model, weight)
device_manager = DeviceManager(device)
save_json_data(datamodule, model, predictions_file, device_manager)
save_json_data(datamodule, model, device_manager, predictions_meta_file)
predictions = run(model, datamodule, device_manager)
if predictions_file.exists():
backup = predictions_file.parent / (predictions_file.name + "~")
logger.warning(
f"Output predictions file `{str(predictions_file)}` exists - "
f"backing it up to `{str(backup)}`...",
)
shutil.copy(predictions_file, backup)
predictions_file = output_folder / "predictions.json"
save_json_with_backup(predictions_file, predictions)
with predictions_file.open("w") as f:
json.dump(predictions, f, indent=2)
......
......@@ -7,15 +7,15 @@ import pathlib
import typing
import click
from clapper.click import ResourceOption
import mednet.libs.common.data.datamodule
import mednet.libs.common.models.model
from clapper.logging import setup
from mednet.libs.common.models.model import Model
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
def reusable_options(f):
"""Wrap reusable predict script options (for ``experiment``).
def reusable_options(f: typing.Callable):
"""Wrap reusable predict script options for other scripts.
This decorator equips the target function ``f`` with all (reusable)
``predict`` script options.
......@@ -30,6 +30,7 @@ def reusable_options(f):
-------
The decorated version of function ``f``
"""
from clapper.click import ResourceOption
@click.option(
"--output-folder",
......@@ -121,21 +122,24 @@ def reusable_options(f):
def setup_datamodule(
datamodule,
model,
batch_size,
parallel,
datamodule: mednet.libs.common.data.datamodule.ConcatDataModule,
model: mednet.libs.common.models.model.Model,
batch_size: int,
parallel: int,
) -> None: # numpydoc ignore=PR01
"""Configure and set up the datamodule."""
datamodule.batch_size = batch_size
datamodule.parallel = parallel
datamodule.model_transforms = model.model_transforms
datamodule.model_transforms = list(model.model_transforms)
datamodule.prepare_data()
datamodule.setup(stage="predict")
def load_checkpoint(model: Model, weight: pathlib.Path) -> Model:
def load_checkpoint(
model: mednet.libs.common.models.model.Model, weight: pathlib.Path
) -> mednet.libs.common.models.model.Model:
"""Load a model checkpoint for prediction.
Parameters
......@@ -162,10 +166,10 @@ def load_checkpoint(model: Model, weight: pathlib.Path) -> Model:
def save_json_data(
datamodule,
model,
output,
datamodule: mednet.libs.common.data.datamodule.ConcatDataModule,
model: mednet.libs.common.models.model.Model,
device_manager,
output_file: pathlib.Path,
) -> None: # numpydoc ignore=PR01
"""Save prediction hyperparameters into a .json file."""
......@@ -187,4 +191,4 @@ def save_json_data(
)
json_data.update(model_summary(model))
json_data = {k.replace("_", "-"): v for k, v in json_data.items()}
save_json_with_backup(output.with_suffix(".meta.json"), json_data)
save_json_with_backup(output_file, json_data)
......@@ -370,7 +370,7 @@ def load_count(
data = numpy.zeros((len(thresholds), 4), dtype=numpy.uint64)
for sample in tqdm(predictions, desc="sample"):
with h5py.File(prediction_path / sample[1], "r") as f:
pred = numpy.array(f.get("img")) # float32
pred = numpy.array(f.get("prediction")) # float32
gt = numpy.array(f.get("target")) # boolean
mask = numpy.array(f.get("mask")) # boolean
data += numpy.array(
......@@ -411,7 +411,7 @@ def load_predictions(
# peak prediction size and number of samples
with h5py.File(prediction_path / predictions[0][1], "r") as f:
elements = numpy.array(f.get("img").shape).prod()
elements = numpy.array(f.get("prediction").shape).prod()
size = len(predictions) * elements
logger.info(
f"Data loading will require ({elements} x {len(predictions)} x 5 =) "
......@@ -424,7 +424,7 @@ def load_predictions(
for i, sample in enumerate(tqdm(predictions, desc="sample")):
with h5py.File(prediction_path / sample[1], "r") as f:
mask = numpy.array(f.get("mask")) # boolean
pred = numpy.array(f.get("img")) # float32
pred = numpy.array(f.get("prediction")) # float32
pred *= mask.astype(numpy.float32)
gt = numpy.array(f.get("target")) # boolean
gt &= mask
......
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
import logging
import pathlib
import typing
import h5py
import lightning.pytorch
import lightning.pytorch.callbacks
import torch.utils.data
import tqdm
from mednet.libs.common.engine.device import DeviceManager
logger = logging.getLogger("mednet")
class _HDF5Writer(lightning.pytorch.callbacks.BasePredictionWriter):
"""Write HDF5 files for each sample processed by our model.
Objects of this class can also keep track of samples written to disk and
return a summary list.
Parameters
----------
output_folder
Base directory where to write predictions to.
write_interval
When will this callback be active.
"""
def __init__(
self,
output_folder: pathlib.Path,
write_interval: typing.Literal["batch", "epoch", "batch_and_epoch"] = "batch",
):
super().__init__(write_interval=write_interval)
self.output_folder = output_folder
self._written: list[list[str]] = []
def write_on_batch_end(
self,
trainer: lightning.pytorch.Trainer,
pl_module: lightning.pytorch.LightningModule,
prediction: typing.Any,
batch_indices: typing.Sequence[int] | None,
batch: typing.Any,
batch_idx: int,
dataloader_idx: int,
) -> None:
"""Write batch predictions to disk.
Parameters
----------
trainer
The trainer being used.
pl_module
The pytorch module.
prediction
The actual predictions to record.
batch_indices
The relative position of samples on the epoch.
batch
The current batch.
batch_idx
Index of the batch overall.
dataloader_idx
Index of the dataloader overall.
"""
for k, p in enumerate(prediction):
stem = pathlib.Path(p[0]).with_suffix(".hdf5")
output_path = self.output_folder / stem
tqdm.tqdm.write(f"`{p[0]}` -> `{str(output_path)}`")
output_path.parent.mkdir(parents=True, exist_ok=True)
with h5py.File(output_path, "w") as f:
f.create_dataset(
"image",
data=batch[0][k].numpy(),
compression="gzip",
compression_opts=9,
)
f.create_dataset(
"prediction",
data=p[3].numpy().squeeze(0),
compression="gzip",
compression_opts=9,
)
f.create_dataset(
"target",
data=(batch[1]["target"][k].squeeze(0).numpy() > 0.5),
compression="gzip",
compression_opts=9,
)
f.create_dataset(
"mask",
data=(batch[1]["mask"][k].squeeze(0).numpy() > 0.5),
compression="gzip",
compression_opts=9,
)
self._written.append([p[0], str(stem)])
def written(self) -> list[list[str]]:
"""Summary of written objects.
Also resets the internal state.
Returns
-------
A list containing a summary of all samples written.
"""
retval = self._written
self._written = []
return retval
def run(
model: lightning.pytorch.LightningModule,
datamodule: lightning.pytorch.LightningDataModule,
device_manager: DeviceManager,
output_folder: pathlib.Path,
) -> dict[str, list[list[str]]] | list[list[list[str]]] | list[list[str]] | None:
"""Run inference on input data, output predictions.
Parameters
----------
model
Neural network model (e.g. pasa).
datamodule
The lightning DataModule to use for training **and** validation.
device_manager
An internal device representation, to be used for training and
validation. This representation can be converted into a pytorch device
or a lightning accelerator setup.
output_folder
Folder where to store HDF5 representations of probability maps.
Returns
-------
A JSON-able representation of sample data stored at ``output_folder``.
For every split (dataloader), a list of samples in the form
``[sample-name, hdf5-path]`` is returned. In the cases where the
``predict_dataloader()`` returns a single loader, we then return a
list. A dictionary is returned in case ``predict_dataloader()`` also
returns a dictionary.
Raises
------
TypeError
If the DataModule's ``predict_dataloader()`` method does not return any
of the types described above.
"""
from lightning.pytorch.loggers.logger import DummyLogger
writer = _HDF5Writer(output_folder)
accelerator, devices = device_manager.lightning_accelerator()
trainer = lightning.pytorch.Trainer(
accelerator=accelerator,
devices=devices,
logger=DummyLogger(),
callbacks=[writer],
)
dataloaders = datamodule.predict_dataloader()
if isinstance(dataloaders, torch.utils.data.DataLoader):
logger.info("Running prediction on a single dataloader...")
trainer.predict(model, dataloaders, return_predictions=False)
return writer.written()
if isinstance(dataloaders, list):
retval_list = []
for k, dataloader in enumerate(dataloaders):
logger.info(f"Running prediction on split `{k}`...")
trainer.predict(model, dataloader, return_predictions=False)
retval_list.append(writer.written())
return retval_list
if isinstance(dataloaders, dict):
retval_dict = {}
for name, dataloader in dataloaders.items():
logger.info(f"Running prediction on `{name}` split...")
trainer.predict(model, dataloader, return_predictions=False)
retval_dict[name] = writer.written()
return retval_dict
if dataloaders is None:
logger.warning("Datamodule did not return any prediction dataloaders!")
return None
# if you get to this point, then the user is returning something that is
# not supported - complain!
raise TypeError(
f"Datamodule returned strangely typed prediction "
f"dataloaders: `{type(dataloaders)}` - Please write code "
f"to support this use-case.",
)
......@@ -15,7 +15,7 @@ from mednet.libs.segmentation.engine.evaluator import SUPPORTED_METRIC_TYPE
logger = setup("mednet")
def _validate_threshold(threshold: float | str, splits: list[str]):
def validate_threshold(threshold: float | str, splits: list[str]):
"""Validate the user threshold selection and returns parsed threshold.
Parameters
......@@ -190,7 +190,7 @@ def evaluate(
json_data = {k.replace("_", "-"): v for k, v in json_data.items()}
save_json_with_backup(evaluation_file.with_suffix(".meta.json"), json_data)
threshold = _validate_threshold(threshold, predict_data)
threshold = validate_threshold(threshold, predict_data)
threshold_list = numpy.arange(
0.0, (1.0 + 1 / steps), 1 / steps, dtype=numpy.float64
)
......
......@@ -2,86 +2,40 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later
import json
import pathlib
import clapper.click
import clapper.logging
import click
from clapper.click import ResourceOption, verbosity_option
from clapper.logging import setup
from mednet.libs.common.scripts.click import ConfigCommand
from mednet.libs.common.scripts.predict import reusable_options
from PIL import Image
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
def _save_hdf5(
img: Image,
target: Image,
mask: Image,
hdf5_path: pathlib.Path,
) -> None:
"""Save prediction image, target and mask in an hdf5 file.
Parameters
----------
img
Monochrome Image with prediction maps.
target
Target corresponding to the prediction.
mask
Mask corresponding to the prediction.
hdf5_path
File in which to save the data.
"""
import h5py
from tqdm import tqdm
tqdm.write(f"Saving {hdf5_path}...")
hdf5_path.parent.mkdir(parents=True, exist_ok=True)
with h5py.File(hdf5_path, "w") as f:
f.create_dataset(
"img",
data=img.squeeze(0),
compression="gzip",
compression_opts=9,
)
f.create_dataset(
"target",
data=target.squeeze(0),
compression="gzip",
compression_opts=9,
)
f.create_dataset(
"mask",
data=mask.squeeze(0),
compression="gzip",
compression_opts=9,
)
import mednet.libs.common.scripts.click
import mednet.libs.common.scripts.predict
logger = clapper.logging.setup(
__name__.split(".")[0], format="%(levelname)s: %(message)s"
)
@click.command(
entry_point_group="mednet.libs.segmentation.config",
cls=ConfigCommand,
cls=mednet.libs.common.scripts.click.ConfigCommand,
epilog="""Examples:
1. Run prediction on an existing DataModule configuration:
.. code:: sh
mednet segmentation predict -vv lwnet drive --weight=path/to/model.ckpt --output=path/to/predictions.json
mednet segmentation predict -vv lwnet drive --weight=path/to/model.ckpt --output-folder=path/to/predictions
2. Enable multi-processing data loading with 6 processes:
.. code:: sh
mednet segmentation predict -vv lwnet drive --parallel=6 --weight=path/to/model.ckpt --output=path/to/predictions.json
mednet segmentation predict -vv lwnet drive --parallel=6 --weight=path/to/model.ckpt --output-folder=path/to/predictions
""",
)
@reusable_options
@verbosity_option(logger=logger, cls=ResourceOption, expose_value=False)
@mednet.libs.common.scripts.predict.reusable_options
@clapper.click.verbosity_option(
logger=logger, cls=clapper.click.ResourceOption, expose_value=False
)
def predict(
output_folder,
model,
......@@ -94,44 +48,22 @@ def predict(
) -> None: # numpydoc ignore=PR01
"""Run inference (generates scores) on all input images, using a pre-trained model."""
import typing
from mednet.libs.classification.engine.predictor import run
from mednet.libs.common.engine.device import DeviceManager
from mednet.libs.common.scripts.predict import (
load_checkpoint,
save_json_data,
setup_datamodule,
)
from mednet.libs.common.scripts.utils import save_json_with_backup
from mednet.libs.segmentation.engine.predictor import run
predictions_file = output_folder / "predictions.json"
predictions_meta_file = output_folder / "predictions.meta.json"
setup_datamodule(datamodule, model, batch_size, parallel)
model = load_checkpoint(model, weight)
device_manager = DeviceManager(device)
save_json_data(datamodule, model, predictions_file, device_manager)
predictions = run(model, datamodule, device_manager)
# Save image data (sample, target, mask) into an hdf5 file
json_predictions = {}
assert isinstance(
predictions, typing.Mapping
), "predictions must be a dictionary or this program will not work!"
for split_name, split in predictions.items():
pred_paths = []
for sample in split:
hdf5_path = pathlib.Path(f"{sample[0]}").with_suffix(".hdf5")
_save_hdf5(
sample[3].numpy(), # float32
sample[1].numpy() > 0.5, # boolean
sample[2].numpy() > 0.5, # boolean
output_folder / hdf5_path,
)
pred_paths.append([str(sample[0]), str(hdf5_path)])
json_predictions[split_name] = pred_paths
# Save path to hdf5 files into predictions.json
with predictions_file.open("w") as f:
json.dump(json_predictions, f, indent=2)
logger.info(f"Predictions saved to `{str(predictions_file)}`")
save_json_data(datamodule, model, device_manager, predictions_meta_file)
json_predictions = run(model, datamodule, device_manager, output_folder)
save_json_with_backup(output_folder / "predictions.json", json_predictions)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment