From 3fbc9e6acad56d38750210e7213e6c212624d6d7 Mon Sep 17 00:00:00 2001 From: dcarron <daniel.carron@idiap.ch> Date: Tue, 21 May 2024 15:38:06 +0200 Subject: [PATCH] [segmetnation-predict] Save target and mask in hdf5 file --- .../libs/segmentation/models/separate.py | 9 ++- src/mednet/libs/segmentation/models/typing.py | 19 ++---- .../libs/segmentation/scripts/experiment.py | 3 +- .../libs/segmentation/scripts/predict.py | 63 +++++++++++++------ 4 files changed, 55 insertions(+), 39 deletions(-) diff --git a/src/mednet/libs/segmentation/models/separate.py b/src/mednet/libs/segmentation/models/separate.py index 6b9bb5de..716addc6 100644 --- a/src/mednet/libs/segmentation/models/separate.py +++ b/src/mednet/libs/segmentation/models/separate.py @@ -7,12 +7,12 @@ import typing from mednet.libs.common.data.typing import Sample -from .typing import BinaryPrediction, MultiClassPrediction +from .typing import SegmentationPrediction def _as_predictions( samples: typing.Iterable[Sample], -) -> list[BinaryPrediction | MultiClassPrediction]: +) -> list[SegmentationPrediction]: """Take a list of separated batch predictions and transforms it into a list of formal predictions. @@ -23,13 +23,12 @@ def _as_predictions( Returns ------- - list[BinaryPrediction | MultiClassPrediction] A list of typed predictions that can be saved to disk. """ - return [(v[1]["name"], v[1]["target"], v[0]) for v in samples] + return [(v[1]["name"], v[1]["target"], v[1]["mask"], v[0]) for v in samples] -def separate(batch: Sample) -> list[BinaryPrediction | MultiClassPrediction]: +def separate(batch: Sample) -> list[SegmentationPrediction]: """Separate a collated batch, reconstituting its samples. This function implements the inverse of diff --git a/src/mednet/libs/segmentation/models/typing.py b/src/mednet/libs/segmentation/models/typing.py index 3eb9017c..d452c519 100644 --- a/src/mednet/libs/segmentation/models/typing.py +++ b/src/mednet/libs/segmentation/models/typing.py @@ -5,23 +5,12 @@ import typing +import torch + Checkpoint: typing.TypeAlias = typing.MutableMapping[str, typing.Any] """Definition of a lightning checkpoint.""" -BinaryPrediction: typing.TypeAlias = tuple[str, int, float] -"""The sample name, the target, and the predicted value.""" - -MultiClassPrediction: typing.TypeAlias = tuple[ - str, typing.Sequence[int], typing.Sequence[float] +SegmentationPrediction: typing.TypeAlias = tuple[ + str, torch.Tensor, torch.Tensor, torch.Tensor ] """The sample name, the target, and the predicted value.""" - -BinaryPredictionSplit: typing.TypeAlias = typing.Mapping[ - str, typing.Sequence[BinaryPrediction] -] -"""A series of predictions for different database splits.""" - -MultiClassPredictionSplit: typing.TypeAlias = typing.Mapping[ - str, typing.Sequence[MultiClassPrediction] -] -"""A series of predictions for different database splits.""" diff --git a/src/mednet/libs/segmentation/scripts/experiment.py b/src/mednet/libs/segmentation/scripts/experiment.py index ccae0eae..75bbfbca 100644 --- a/src/mednet/libs/segmentation/scripts/experiment.py +++ b/src/mednet/libs/segmentation/scripts/experiment.py @@ -138,7 +138,8 @@ def experiment( ctx.invoke( evaluate, - predictions=predictions_output, + datamodule=datamodule, + predictions_folder=predictions_output, output_folder=output_folder, # threshold="validation", threshold=0.5, diff --git a/src/mednet/libs/segmentation/scripts/predict.py b/src/mednet/libs/segmentation/scripts/predict.py index b7492f53..d0195b7c 100644 --- a/src/mednet/libs/segmentation/scripts/predict.py +++ b/src/mednet/libs/segmentation/scripts/predict.py @@ -2,7 +2,7 @@ # # SPDX-License-Identifier: GPL-3.0-or-later - +import json import pathlib import click @@ -18,29 +18,45 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") def _save_hdf5( - stem: pathlib.Path, prob: PIL.Image.Image, output_folder: pathlib.Path -): - """Save prediction maps as image in the same format as the test image. + img: PIL.Image.Image, + target: PIL.Image.Image, + mask: PIL.Image.Image, + hdf5_path: pathlib.Path, +) -> None: + """Save prediction image, target and mask in an hdf5 file. Parameters ---------- - stem - Name of the file without extension on the original dataset. - - prob + img Monochrome Image with prediction maps. - - output_folder - Directory in which to store predictions. + target + Target corresponding to the prediction. + mask + Mask corresponding to the prediction. + hdf5_path + File in which to save the data. """ - fullpath = output_folder / f"{stem}.hdf5" - tqdm.write(f"Saving {fullpath}...") - fullpath.parent.mkdir(parents=True, exist_ok=True) - with h5py.File(fullpath, "w") as f: - data = prob.squeeze(0).numpy() + tqdm.write(f"Saving {hdf5_path}...") + hdf5_path.parent.mkdir(parents=True, exist_ok=True) + with h5py.File(hdf5_path, "w") as f: + f.create_dataset( + "img", + data=img.squeeze(0).numpy(), + compression="gzip", + compression_opts=9, + ) + f.create_dataset( + "target", + data=target.squeeze(0).numpy(), + compression="gzip", + compression_opts=9, + ) f.create_dataset( - "array", data=data, compression="gzip", compression_opts=9 + "mask", + data=mask.squeeze(0).numpy(), + compression="gzip", + compression_opts=9, ) @@ -94,6 +110,17 @@ def predict( predictions = run(model, datamodule, device_manager) + # Save image data (sample, target, mask) into an hdf5 file + json_predictions = {} for split_name, split in predictions.items(): + pred_paths = [] for sample in split: - _save_hdf5(sample[0], sample[2], output_folder) + hdf5_path = output_folder / f"{sample[0]}.hdf5" + _save_hdf5(sample[3], sample[1], sample[2], hdf5_path) + pred_paths.append([str(sample[0]), str(hdf5_path)]) + json_predictions[split_name] = pred_paths + + # Save path to hdf5 files into predictions.json + with predictions_file.open("w") as f: + json.dump(json_predictions, f, indent=2) + logger.info(f"Predictions saved to `{str(predictions_file)}`") -- GitLab