From c2e0c01c3bca3bde40b78c0ca18e840ad8865077 Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Thu, 16 May 2024 15:27:56 +0200
Subject: [PATCH] [predict] Split predict into functions and implement in
 segmentation

---
 .../libs/classification/scripts/experiment.py |  6 +-
 .../libs/classification/scripts/predict.py    | 20 +++---
 .../tests/test_cli_classification.py          | 17 +++--
 src/mednet/libs/common/scripts/predict.py     | 14 ++--
 .../libs/segmentation/engine/__init__.py      |  0
 .../libs/segmentation/models/separate.py      |  6 +-
 .../libs/segmentation/scripts/experiment.py   | 11 ++--
 .../libs/segmentation/scripts/predict.py      | 65 +++++++++++--------
 8 files changed, 78 insertions(+), 61 deletions(-)
 create mode 100644 src/mednet/libs/segmentation/engine/__init__.py

diff --git a/src/mednet/libs/classification/scripts/experiment.py b/src/mednet/libs/classification/scripts/experiment.py
index 814f3b76..6f11690a 100644
--- a/src/mednet/libs/classification/scripts/experiment.py
+++ b/src/mednet/libs/classification/scripts/experiment.py
@@ -111,11 +111,11 @@ def experiment(
 
     from .predict import predict
 
-    predictions_output = output_folder / "predictions.json"
+    predictions_output = output_folder / "predictions"
 
     ctx.invoke(
         predict,
-        output=predictions_output,
+        output_folder=predictions_output,
         model=model,
         datamodule=datamodule,
         device=device,
@@ -135,7 +135,7 @@ def experiment(
 
     ctx.invoke(
         evaluate,
-        predictions=predictions_output,
+        predictions=predictions_output / "predictions.json",
         output_folder=output_folder,
         threshold="validation",
     )
diff --git a/src/mednet/libs/classification/scripts/predict.py b/src/mednet/libs/classification/scripts/predict.py
index 04e9908d..2691727d 100644
--- a/src/mednet/libs/classification/scripts/predict.py
+++ b/src/mednet/libs/classification/scripts/predict.py
@@ -34,7 +34,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 @reusable_options
 @verbosity_option(logger=logger, cls=ResourceOption, expose_value=False)
 def predict(
-    output,
+    output_folder,
     model,
     datamodule,
     batch_size,
@@ -56,22 +56,24 @@ def predict(
         setup_datamodule,
     )
 
+    predictions_file = output_folder / "predictions.json"
+    predictions_file.parent.mkdir(parents=True, exist_ok=True)
+
     setup_datamodule(datamodule, model, batch_size, parallel)
     model = load_checkpoint(model, weight)
     device_manager = DeviceManager(device)
-    save_json_data(datamodule, model, output, device_manager)
+    save_json_data(datamodule, model, predictions_file, device_manager)
 
     predictions = run(model, datamodule, device_manager)
 
-    output.parent.mkdir(parents=True, exist_ok=True)
-    if output.exists():
-        backup = output.parent / (output.name + "~")
+    if predictions_file.exists():
+        backup = predictions_file.parent / (predictions_file.name + "~")
         logger.warning(
-            f"Output predictions file `{str(output)}` exists - "
+            f"Output predictions file `{str(predictions_file)}` exists - "
             f"backing it up to `{str(backup)}`...",
         )
-        shutil.copy(output, backup)
+        shutil.copy(predictions_file, backup)
 
-    with output.open("w") as f:
+    with predictions_file.open("w") as f:
         json.dump(predictions, f, indent=2)
-    logger.info(f"Predictions saved to `{str(output)}`")
+    logger.info(f"Predictions saved to `{str(predictions_file)}`")
diff --git a/src/mednet/libs/classification/tests/test_cli_classification.py b/src/mednet/libs/classification/tests/test_cli_classification.py
index f735ae17..99e70d8f 100644
--- a/src/mednet/libs/classification/tests/test_cli_classification.py
+++ b/src/mednet/libs/classification/tests/test_cli_classification.py
@@ -329,7 +329,7 @@ def test_predict_pasa_montgomery(temporary_basedir, datadir):
     runner = CliRunner()
 
     with stdout_logging() as buf:
-        output = temporary_basedir / "predictions.json"
+        output = temporary_basedir / "predictions"
         last = _get_checkpoint_from_alias(
             temporary_basedir / "results",
             "periodic",
@@ -343,7 +343,7 @@ def test_predict_pasa_montgomery(temporary_basedir, datadir):
                 "-vv",
                 "--batch-size=1",
                 f"--weight={str(last)}",
-                f"--output={str(output)}",
+                f"--output-folder={str(output)}",
             ],
         )
         _assert_exit_0(result)
@@ -379,7 +379,8 @@ def test_evaluate_pasa_montgomery(temporary_basedir):
     runner = CliRunner()
 
     with stdout_logging() as buf:
-        prediction_path = temporary_basedir / "predictions.json"
+        prediction_path = temporary_basedir / "predictions"
+        predictions_file = prediction_path / "predictions.json"
         evaluation_filename = "evaluation.json"
         evaluation_file = temporary_basedir / evaluation_filename
         result = runner.invoke(
@@ -387,7 +388,7 @@ def test_evaluate_pasa_montgomery(temporary_basedir):
             [
                 "-vv",
                 "montgomery",
-                f"--predictions={str(prediction_path)}",
+                f"--predictions={predictions_file}",
                 f"--output-folder={str(temporary_basedir)}",
                 "--threshold=test",
             ],
@@ -440,9 +441,11 @@ def test_experiment(temporary_basedir):
     _assert_exit_0(result)
 
     assert (output_folder / "model" / "meta.json").exists()
-    assert (output_folder / "model" / f"model-at-epoch={num_epochs-1}.ckpt").exists()
-    assert (output_folder / "predictions.json").exists()
-    assert (output_folder / "predictions.meta.json").exists()
+    assert (
+        output_folder / "model" / f"model-at-epoch={num_epochs-1}.ckpt"
+    ).exists()
+    assert (output_folder / "predictions" / "predictions.json").exists()
+    assert (output_folder / "predictions" / "predictions.meta.json").exists()
 
     # Need to glob because we cannot be sure of the checkpoint with lowest validation loss
     assert (
diff --git a/src/mednet/libs/common/scripts/predict.py b/src/mednet/libs/common/scripts/predict.py
index 2cde0d8d..48fbf4d1 100644
--- a/src/mednet/libs/common/scripts/predict.py
+++ b/src/mednet/libs/common/scripts/predict.py
@@ -32,20 +32,18 @@ def reusable_options(f):
     """
 
     @click.option(
-        "--output",
+        "--output-folder",
         "-o",
-        help="""Path to a JSON file in which to save predictions for all samples in the
-        input DataModule (leading directories are created if they do not
-        exist).""",
+        help="Directory in which to save predictions (created if does not exist)",
         required=True,
-        default="predictions.json",
-        cls=ResourceOption,
         type=click.Path(
-            file_okay=True,
-            dir_okay=False,
+            file_okay=False,
+            dir_okay=True,
             writable=True,
             path_type=pathlib.Path,
         ),
+        default="predictions",
+        cls=ResourceOption,
     )
     @click.option(
         "--model",
diff --git a/src/mednet/libs/segmentation/engine/__init__.py b/src/mednet/libs/segmentation/engine/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/mednet/libs/segmentation/models/separate.py b/src/mednet/libs/segmentation/models/separate.py
index 7f520a4a..6b9bb5de 100644
--- a/src/mednet/libs/segmentation/models/separate.py
+++ b/src/mednet/libs/segmentation/models/separate.py
@@ -5,7 +5,6 @@
 
 import typing
 
-import torch
 from mednet.libs.common.data.typing import Sample
 
 from .typing import BinaryPrediction, MultiClassPrediction
@@ -27,7 +26,7 @@ def _as_predictions(
     list[BinaryPrediction | MultiClassPrediction]
         A list of typed predictions that can be saved to disk.
     """
-    return [(v[1]["name"], v[1]["target"], v[0].item()) for v in samples]
+    return [(v[1]["name"], v[1]["target"], v[0]) for v in samples]
 
 
 def separate(batch: Sample) -> list[BinaryPrediction | MultiClassPrediction]:
@@ -58,4 +57,5 @@ def separate(batch: Sample) -> list[BinaryPrediction | MultiClassPrediction]:
         {key: value[i] for key, value in batch[1].items()}
         for i in range(len(batch[0]))
     ]
-    return _as_predictions(zip(torch.flatten(batch[0]), metadata))
+
+    return _as_predictions(zip(batch[0], metadata))
diff --git a/src/mednet/libs/segmentation/scripts/experiment.py b/src/mednet/libs/segmentation/scripts/experiment.py
index 0760adde..ccae0eae 100644
--- a/src/mednet/libs/segmentation/scripts/experiment.py
+++ b/src/mednet/libs/segmentation/scripts/experiment.py
@@ -112,11 +112,11 @@ def experiment(
 
     from .predict import predict
 
-    predictions_output = output_folder / "predictions.json"
+    predictions_output = output_folder / "predictions"
 
     ctx.invoke(
         predict,
-        output=predictions_output,
+        output_folder=predictions_output,
         model=model,
         datamodule=datamodule,
         device=device,
@@ -131,7 +131,7 @@ def experiment(
         f"Prediction runtime: {predict_stop_timestamp-predict_start_timestamp}"
     )
 
-    """evaluation_start_timestamp = datetime.now()
+    evaluation_start_timestamp = datetime.now()
     logger.info(f"Started evaluation at {evaluation_start_timestamp}")
 
     from .evaluate import evaluate
@@ -140,14 +140,15 @@ def experiment(
         evaluate,
         predictions=predictions_output,
         output_folder=output_folder,
-        threshold="validation",
+        # threshold="validation",
+        threshold=0.5,
     )
 
     evaluation_stop_timestamp = datetime.now()
     logger.info(f"Ended prediction in {evaluation_stop_timestamp}")
     logger.info(
         f"Prediction runtime: {evaluation_stop_timestamp-evaluation_start_timestamp}"
-    )"""
+    )
 
     experiment_stop_timestamp = datetime.now()
     logger.info(
diff --git a/src/mednet/libs/segmentation/scripts/predict.py b/src/mednet/libs/segmentation/scripts/predict.py
index 9e0b44b0..b7492f53 100644
--- a/src/mednet/libs/segmentation/scripts/predict.py
+++ b/src/mednet/libs/segmentation/scripts/predict.py
@@ -3,15 +3,47 @@
 # SPDX-License-Identifier: GPL-3.0-or-later
 
 
+import pathlib
+
 import click
+import h5py
+import PIL
 from clapper.click import ResourceOption, verbosity_option
 from clapper.logging import setup
 from mednet.libs.common.scripts.click import ConfigCommand
 from mednet.libs.common.scripts.predict import reusable_options
+from tqdm import tqdm
 
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
 
+def _save_hdf5(
+    stem: pathlib.Path, prob: PIL.Image.Image, output_folder: pathlib.Path
+):
+    """Save a prediction map to disk as a gzip-compressed HDF5 dataset.
+
+    Parameters
+    ----------
+    stem
+        Name of the file without extension on the original dataset.
+
+    prob
+        Monochrome prediction map.  NOTE(review): the body calls ``.squeeze(0).numpy()``, so despite the ``PIL.Image.Image`` annotation a tensor-like object is expected — confirm and fix the annotation.
+
+    output_folder
+        Directory in which to store predictions.
+    """
+
+    fullpath = output_folder / f"{stem}.hdf5"
+    tqdm.write(f"Saving {fullpath}...")
+    fullpath.parent.mkdir(parents=True, exist_ok=True)
+    with h5py.File(fullpath, "w") as f:
+        data = prob.squeeze(0).numpy()
+        f.create_dataset(
+            "array", data=data, compression="gzip", compression_opts=9
+        )
+
+
 @click.command(
     entry_point_group="mednet.libs.segmentation.config",
     cls=ConfigCommand,
@@ -34,7 +66,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 @reusable_options
 @verbosity_option(logger=logger, cls=ResourceOption, expose_value=False)
 def predict(
-    output,
+    output_folder,
     model,
     datamodule,
     batch_size,
@@ -45,9 +77,6 @@ def predict(
 ) -> None:  # numpydoc ignore=PR01
     """Run inference (generates scores) on all input images, using a pre-trained model."""
 
-    import json
-    import shutil
-
     from mednet.libs.classification.engine.predictor import run
     from mednet.libs.common.engine.device import DeviceManager
     from mednet.libs.common.scripts.predict import (
@@ -56,31 +85,15 @@ def predict(
         setup_datamodule,
     )
 
+    predictions_file = output_folder / "predictions.json"
+
     setup_datamodule(datamodule, model, batch_size, parallel)
     model = load_checkpoint(model, weight)
     device_manager = DeviceManager(device)
-    save_json_data(datamodule, model, output, device_manager)
+    save_json_data(datamodule, model, predictions_file, device_manager)
 
     predictions = run(model, datamodule, device_manager)
 
-    output.parent.mkdir(parents=True, exist_ok=True)
-    if output.exists():
-        backup = output.parent / (output.name + "~")
-        logger.warning(
-            f"Output predictions file `{str(output)}` exists - "
-            f"backing it up to `{str(backup)}`...",
-        )
-        shutil.copy(output, backup)
-
-    # Remove targets from predictions, as they are images and not serializable
-    # A better solution should be found
-    serializable_predictions = {}
-    for split_name, sample in predictions.items():
-        split_predictions = []
-        for s in sample:
-            split_predictions.append((s[0], s[2]))
-        serializable_predictions[split_name] = split_predictions
-
-    with output.open("w") as f:
-        json.dump(serializable_predictions, f, indent=2)
-    logger.info(f"Predictions saved to `{str(output)}`")
+    for split_name, split in predictions.items():
+        for sample in split:
+            _save_hdf5(sample[0], sample[2], output_folder)
-- 
GitLab