Skip to content
Snippets Groups Projects
Commit 3fbc9e6a authored by Daniel CARRON's avatar Daniel CARRON :b: Committed by André Anjos
Browse files

[segmetnation-predict] Save target and mask in hdf5 file

parent c2e0c01c
No related branches found
No related tags found
1 merge request!46Create common library
...@@ -7,12 +7,12 @@ import typing ...@@ -7,12 +7,12 @@ import typing
from mednet.libs.common.data.typing import Sample from mednet.libs.common.data.typing import Sample
from .typing import BinaryPrediction, MultiClassPrediction from .typing import SegmentationPrediction
def _as_predictions( def _as_predictions(
samples: typing.Iterable[Sample], samples: typing.Iterable[Sample],
) -> list[BinaryPrediction | MultiClassPrediction]: ) -> list[SegmentationPrediction]:
"""Take a list of separated batch predictions and transforms it into a list """Take a list of separated batch predictions and transforms it into a list
of formal predictions. of formal predictions.
...@@ -23,13 +23,12 @@ def _as_predictions( ...@@ -23,13 +23,12 @@ def _as_predictions(
Returns Returns
------- -------
list[BinaryPrediction | MultiClassPrediction]
A list of typed predictions that can be saved to disk. A list of typed predictions that can be saved to disk.
""" """
return [(v[1]["name"], v[1]["target"], v[0]) for v in samples] return [(v[1]["name"], v[1]["target"], v[1]["mask"], v[0]) for v in samples]
def separate(batch: Sample) -> list[BinaryPrediction | MultiClassPrediction]: def separate(batch: Sample) -> list[SegmentationPrediction]:
"""Separate a collated batch, reconstituting its samples. """Separate a collated batch, reconstituting its samples.
This function implements the inverse of This function implements the inverse of
......
...@@ -5,23 +5,12 @@ ...@@ -5,23 +5,12 @@
import typing import typing
import torch
Checkpoint: typing.TypeAlias = typing.MutableMapping[str, typing.Any] Checkpoint: typing.TypeAlias = typing.MutableMapping[str, typing.Any]
"""Definition of a lightning checkpoint.""" """Definition of a lightning checkpoint."""
BinaryPrediction: typing.TypeAlias = tuple[str, int, float] SegmentationPrediction: typing.TypeAlias = tuple[
"""The sample name, the target, and the predicted value.""" str, torch.Tensor, torch.Tensor, torch.Tensor
MultiClassPrediction: typing.TypeAlias = tuple[
str, typing.Sequence[int], typing.Sequence[float]
] ]
"""The sample name, the target, and the predicted value.""" """The sample name, the target, and the predicted value."""
BinaryPredictionSplit: typing.TypeAlias = typing.Mapping[
str, typing.Sequence[BinaryPrediction]
]
"""A series of predictions for different database splits."""
MultiClassPredictionSplit: typing.TypeAlias = typing.Mapping[
str, typing.Sequence[MultiClassPrediction]
]
"""A series of predictions for different database splits."""
...@@ -138,7 +138,8 @@ def experiment( ...@@ -138,7 +138,8 @@ def experiment(
ctx.invoke( ctx.invoke(
evaluate, evaluate,
predictions=predictions_output, datamodule=datamodule,
predictions_folder=predictions_output,
output_folder=output_folder, output_folder=output_folder,
# threshold="validation", # threshold="validation",
threshold=0.5, threshold=0.5,
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# #
# SPDX-License-Identifier: GPL-3.0-or-later # SPDX-License-Identifier: GPL-3.0-or-later
import json
import pathlib import pathlib
import click import click
...@@ -18,29 +18,45 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") ...@@ -18,29 +18,45 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
def _save_hdf5( def _save_hdf5(
stem: pathlib.Path, prob: PIL.Image.Image, output_folder: pathlib.Path img: PIL.Image.Image,
): target: PIL.Image.Image,
"""Save prediction maps as image in the same format as the test image. mask: PIL.Image.Image,
hdf5_path: pathlib.Path,
) -> None:
"""Save prediction image, target and mask in an hdf5 file.
Parameters Parameters
---------- ----------
stem img
Name of the file without extension on the original dataset.
prob
Monochrome Image with prediction maps. Monochrome Image with prediction maps.
target
output_folder Target corresponding to the prediction.
Directory in which to store predictions. mask
Mask corresponding to the prediction.
hdf5_path
File in which to save the data.
""" """
fullpath = output_folder / f"{stem}.hdf5" tqdm.write(f"Saving {hdf5_path}...")
tqdm.write(f"Saving {fullpath}...") hdf5_path.parent.mkdir(parents=True, exist_ok=True)
fullpath.parent.mkdir(parents=True, exist_ok=True) with h5py.File(hdf5_path, "w") as f:
with h5py.File(fullpath, "w") as f: f.create_dataset(
data = prob.squeeze(0).numpy() "img",
data=img.squeeze(0).numpy(),
compression="gzip",
compression_opts=9,
)
f.create_dataset(
"target",
data=target.squeeze(0).numpy(),
compression="gzip",
compression_opts=9,
)
f.create_dataset( f.create_dataset(
"array", data=data, compression="gzip", compression_opts=9 "mask",
data=mask.squeeze(0).numpy(),
compression="gzip",
compression_opts=9,
) )
...@@ -94,6 +110,17 @@ def predict( ...@@ -94,6 +110,17 @@ def predict(
predictions = run(model, datamodule, device_manager) predictions = run(model, datamodule, device_manager)
# Save image data (sample, target, mask) into an hdf5 file
json_predictions = {}
for split_name, split in predictions.items(): for split_name, split in predictions.items():
pred_paths = []
for sample in split: for sample in split:
_save_hdf5(sample[0], sample[2], output_folder) hdf5_path = output_folder / f"{sample[0]}.hdf5"
_save_hdf5(sample[3], sample[1], sample[2], hdf5_path)
pred_paths.append([str(sample[0]), str(hdf5_path)])
json_predictions[split_name] = pred_paths
# Save path to hdf5 files into predictions.json
with predictions_file.open("w") as f:
json.dump(json_predictions, f, indent=2)
logger.info(f"Predictions saved to `{str(predictions_file)}`")
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment