From 3fbc9e6acad56d38750210e7213e6c212624d6d7 Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Tue, 21 May 2024 15:38:06 +0200
Subject: [PATCH] [segmetnation-predict] Save target and mask in hdf5 file

---
 .../libs/segmentation/models/separate.py      |  9 ++-
 src/mednet/libs/segmentation/models/typing.py | 19 ++----
 .../libs/segmentation/scripts/experiment.py   |  3 +-
 .../libs/segmentation/scripts/predict.py      | 63 +++++++++++++------
 4 files changed, 55 insertions(+), 39 deletions(-)

diff --git a/src/mednet/libs/segmentation/models/separate.py b/src/mednet/libs/segmentation/models/separate.py
index 6b9bb5de..716addc6 100644
--- a/src/mednet/libs/segmentation/models/separate.py
+++ b/src/mednet/libs/segmentation/models/separate.py
@@ -7,12 +7,12 @@ import typing
 
 from mednet.libs.common.data.typing import Sample
 
-from .typing import BinaryPrediction, MultiClassPrediction
+from .typing import SegmentationPrediction
 
 
 def _as_predictions(
     samples: typing.Iterable[Sample],
-) -> list[BinaryPrediction | MultiClassPrediction]:
+) -> list[SegmentationPrediction]:
     """Take a list of separated batch predictions and transforms it into a list
     of formal predictions.
 
@@ -23,13 +23,12 @@ def _as_predictions(
 
     Returns
     -------
-    list[BinaryPrediction | MultiClassPrediction]
         A list of typed predictions that can be saved to disk.
     """
-    return [(v[1]["name"], v[1]["target"], v[0]) for v in samples]
+    return [(v[1]["name"], v[1]["target"], v[1]["mask"], v[0]) for v in samples]
 
 
-def separate(batch: Sample) -> list[BinaryPrediction | MultiClassPrediction]:
+def separate(batch: Sample) -> list[SegmentationPrediction]:
     """Separate a collated batch, reconstituting its samples.
 
     This function implements the inverse of
diff --git a/src/mednet/libs/segmentation/models/typing.py b/src/mednet/libs/segmentation/models/typing.py
index 3eb9017c..d452c519 100644
--- a/src/mednet/libs/segmentation/models/typing.py
+++ b/src/mednet/libs/segmentation/models/typing.py
@@ -5,23 +5,12 @@
 
 import typing
 
+import torch
+
 Checkpoint: typing.TypeAlias = typing.MutableMapping[str, typing.Any]
 """Definition of a lightning checkpoint."""
 
-BinaryPrediction: typing.TypeAlias = tuple[str, int, float]
-"""The sample name, the target, and the predicted value."""
-
-MultiClassPrediction: typing.TypeAlias = tuple[
-    str, typing.Sequence[int], typing.Sequence[float]
+SegmentationPrediction: typing.TypeAlias = tuple[
+    str, torch.Tensor, torch.Tensor, torch.Tensor
 ]
 """The sample name, the target, and the predicted value."""
-
-BinaryPredictionSplit: typing.TypeAlias = typing.Mapping[
-    str, typing.Sequence[BinaryPrediction]
-]
-"""A series of predictions for different database splits."""
-
-MultiClassPredictionSplit: typing.TypeAlias = typing.Mapping[
-    str, typing.Sequence[MultiClassPrediction]
-]
-"""A series of predictions for different database splits."""
diff --git a/src/mednet/libs/segmentation/scripts/experiment.py b/src/mednet/libs/segmentation/scripts/experiment.py
index ccae0eae..75bbfbca 100644
--- a/src/mednet/libs/segmentation/scripts/experiment.py
+++ b/src/mednet/libs/segmentation/scripts/experiment.py
@@ -138,7 +138,8 @@ def experiment(
 
     ctx.invoke(
         evaluate,
-        predictions=predictions_output,
+        datamodule=datamodule,
+        predictions_folder=predictions_output,
         output_folder=output_folder,
         # threshold="validation",
         threshold=0.5,
diff --git a/src/mednet/libs/segmentation/scripts/predict.py b/src/mednet/libs/segmentation/scripts/predict.py
index b7492f53..d0195b7c 100644
--- a/src/mednet/libs/segmentation/scripts/predict.py
+++ b/src/mednet/libs/segmentation/scripts/predict.py
@@ -2,7 +2,7 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
-
+import json
 import pathlib
 
 import click
@@ -18,29 +18,45 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
 
 def _save_hdf5(
-    stem: pathlib.Path, prob: PIL.Image.Image, output_folder: pathlib.Path
-):
-    """Save prediction maps as image in the same format as the test image.
+    img: PIL.Image.Image,
+    target: PIL.Image.Image,
+    mask: PIL.Image.Image,
+    hdf5_path: pathlib.Path,
+) -> None:
+    """Save prediction image, target and mask in an hdf5 file.
 
     Parameters
     ----------
-    stem
-        Name of the file without extension on the original dataset.
-
-    prob
+    img
         Monochrome Image with prediction maps.
-
-    output_folder
-        Directory in which to store predictions.
+    target
+        Target corresponding to the prediction.
+    mask
+        Mask corresponding to the prediction.
+    hdf5_path
+        File in which to save the data.
     """
 
-    fullpath = output_folder / f"{stem}.hdf5"
-    tqdm.write(f"Saving {fullpath}...")
-    fullpath.parent.mkdir(parents=True, exist_ok=True)
-    with h5py.File(fullpath, "w") as f:
-        data = prob.squeeze(0).numpy()
+    tqdm.write(f"Saving {hdf5_path}...")
+    hdf5_path.parent.mkdir(parents=True, exist_ok=True)
+    with h5py.File(hdf5_path, "w") as f:
+        f.create_dataset(
+            "img",
+            data=img.squeeze(0).numpy(),
+            compression="gzip",
+            compression_opts=9,
+        )
+        f.create_dataset(
+            "target",
+            data=target.squeeze(0).numpy(),
+            compression="gzip",
+            compression_opts=9,
+        )
         f.create_dataset(
-            "array", data=data, compression="gzip", compression_opts=9
+            "mask",
+            data=mask.squeeze(0).numpy(),
+            compression="gzip",
+            compression_opts=9,
         )
 
 
@@ -94,6 +110,17 @@ def predict(
 
     predictions = run(model, datamodule, device_manager)
 
+    # Save image data (sample, target, mask) into an hdf5 file
+    json_predictions = {}
     for split_name, split in predictions.items():
+        pred_paths = []
         for sample in split:
-            _save_hdf5(sample[0], sample[2], output_folder)
+            hdf5_path = output_folder / f"{sample[0]}.hdf5"
+            _save_hdf5(sample[3], sample[1], sample[2], hdf5_path)
+            pred_paths.append([str(sample[0]), str(hdf5_path)])
+        json_predictions[split_name] = pred_paths
+
+    # Save path to hdf5 files into predictions.json
+    with predictions_file.open("w") as f:
+        json.dump(json_predictions, f, indent=2)
+    logger.info(f"Predictions saved to `{str(predictions_file)}`")
-- 
GitLab