Skip to content
Snippets Groups Projects
Commit 3fbc9e6a authored by Daniel CARRON's avatar Daniel CARRON :b: Committed by André Anjos
Browse files

[segmetnation-predict] Save target and mask in hdf5 file

parent c2e0c01c
No related branches found
No related tags found
1 merge request!46Create common library
......@@ -7,12 +7,12 @@ import typing
from mednet.libs.common.data.typing import Sample
from .typing import BinaryPrediction, MultiClassPrediction
from .typing import SegmentationPrediction
def _as_predictions(
samples: typing.Iterable[Sample],
) -> list[BinaryPrediction | MultiClassPrediction]:
) -> list[SegmentationPrediction]:
"""Take a list of separated batch predictions and transforms it into a list
of formal predictions.
......@@ -23,13 +23,12 @@ def _as_predictions(
Returns
-------
list[BinaryPrediction | MultiClassPrediction]
A list of typed predictions that can be saved to disk.
"""
return [(v[1]["name"], v[1]["target"], v[0]) for v in samples]
return [(v[1]["name"], v[1]["target"], v[1]["mask"], v[0]) for v in samples]
def separate(batch: Sample) -> list[BinaryPrediction | MultiClassPrediction]:
def separate(batch: Sample) -> list[SegmentationPrediction]:
"""Separate a collated batch, reconstituting its samples.
This function implements the inverse of
......
......@@ -5,23 +5,12 @@
import typing
import torch
Checkpoint: typing.TypeAlias = typing.MutableMapping[str, typing.Any]
"""Definition of a lightning checkpoint."""
BinaryPrediction: typing.TypeAlias = tuple[str, int, float]
"""The sample name, the target, and the predicted value."""
MultiClassPrediction: typing.TypeAlias = tuple[
str, typing.Sequence[int], typing.Sequence[float]
SegmentationPrediction: typing.TypeAlias = tuple[
str, torch.Tensor, torch.Tensor, torch.Tensor
]
"""The sample name, the target, and the predicted value."""
BinaryPredictionSplit: typing.TypeAlias = typing.Mapping[
str, typing.Sequence[BinaryPrediction]
]
"""A series of predictions for different database splits."""
MultiClassPredictionSplit: typing.TypeAlias = typing.Mapping[
str, typing.Sequence[MultiClassPrediction]
]
"""A series of predictions for different database splits."""
......@@ -138,7 +138,8 @@ def experiment(
ctx.invoke(
evaluate,
predictions=predictions_output,
datamodule=datamodule,
predictions_folder=predictions_output,
output_folder=output_folder,
# threshold="validation",
threshold=0.5,
......
......@@ -2,7 +2,7 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later
import json
import pathlib
import click
......@@ -18,29 +18,45 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
def _save_hdf5(
stem: pathlib.Path, prob: PIL.Image.Image, output_folder: pathlib.Path
):
"""Save prediction maps as image in the same format as the test image.
img: PIL.Image.Image,
target: PIL.Image.Image,
mask: PIL.Image.Image,
hdf5_path: pathlib.Path,
) -> None:
"""Save prediction image, target and mask in an hdf5 file.
Parameters
----------
stem
Name of the file without extension on the original dataset.
prob
img
Monochrome Image with prediction maps.
output_folder
Directory in which to store predictions.
target
Target corresponding to the prediction.
mask
Mask corresponding to the prediction.
hdf5_path
File in which to save the data.
"""
fullpath = output_folder / f"{stem}.hdf5"
tqdm.write(f"Saving {fullpath}...")
fullpath.parent.mkdir(parents=True, exist_ok=True)
with h5py.File(fullpath, "w") as f:
data = prob.squeeze(0).numpy()
tqdm.write(f"Saving {hdf5_path}...")
hdf5_path.parent.mkdir(parents=True, exist_ok=True)
with h5py.File(hdf5_path, "w") as f:
f.create_dataset(
"img",
data=img.squeeze(0).numpy(),
compression="gzip",
compression_opts=9,
)
f.create_dataset(
"target",
data=target.squeeze(0).numpy(),
compression="gzip",
compression_opts=9,
)
f.create_dataset(
"array", data=data, compression="gzip", compression_opts=9
"mask",
data=mask.squeeze(0).numpy(),
compression="gzip",
compression_opts=9,
)
......@@ -94,6 +110,17 @@ def predict(
predictions = run(model, datamodule, device_manager)
# Save image data (sample, target, mask) into an hdf5 file
json_predictions = {}
for split_name, split in predictions.items():
pred_paths = []
for sample in split:
_save_hdf5(sample[0], sample[2], output_folder)
hdf5_path = output_folder / f"{sample[0]}.hdf5"
_save_hdf5(sample[3], sample[1], sample[2], hdf5_path)
pred_paths.append([str(sample[0]), str(hdf5_path)])
json_predictions[split_name] = pred_paths
# Save path to hdf5 files into predictions.json
with predictions_file.open("w") as f:
json.dump(json_predictions, f, indent=2)
logger.info(f"Predictions saved to `{str(predictions_file)}`")
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment