From d354514cfb59f811534ce92a56b8392a737fc135 Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Tue, 14 May 2024 11:14:27 +0200
Subject: [PATCH] [predict] Split predict script into functions, fix
 segmentation

---
 .../libs/classification/scripts/predict.py    |  40 +++--
 src/mednet/libs/common/scripts/predict.py     |  83 +++++-----
 src/mednet/libs/segmentation/models/lwnet.py  |   4 +-
 .../libs/segmentation/models/separate.py      |   2 +-
 src/mednet/libs/segmentation/scripts/cli.py   |   7 +-
 .../libs/segmentation/scripts/experiment.py   | 155 ++++++++++++++++++
 .../libs/segmentation/scripts/predict.py      |  86 ++++++++++
 7 files changed, 320 insertions(+), 57 deletions(-)
 create mode 100644 src/mednet/libs/segmentation/scripts/experiment.py
 create mode 100644 src/mednet/libs/segmentation/scripts/predict.py

diff --git a/src/mednet/libs/classification/scripts/predict.py b/src/mednet/libs/classification/scripts/predict.py
index 50d7d03d..04e9908d 100644
--- a/src/mednet/libs/classification/scripts/predict.py
+++ b/src/mednet/libs/classification/scripts/predict.py
@@ -7,7 +7,6 @@ import click
 from clapper.click import ResourceOption, verbosity_option
 from clapper.logging import setup
 from mednet.libs.common.scripts.click import ConfigCommand
-from mednet.libs.common.scripts.predict import predict as predict_script
 from mednet.libs.common.scripts.predict import reusable_options
 
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
@@ -45,13 +44,34 @@ def predict(
     **_,
 ) -> None:  # numpydoc ignore=PR01
     """Run inference (generates scores) on all input images, using a pre-trained model."""
-    predict_script(
-        output,
-        model,
-        datamodule,
-        batch_size,
-        device,
-        weight,
-        parallel,
-        **_,
+
+    import json
+    import shutil
+
+    from mednet.libs.classification.engine.predictor import run
+    from mednet.libs.common.engine.device import DeviceManager
+    from mednet.libs.common.scripts.predict import (
+        load_checkpoint,
+        save_json_data,
+        setup_datamodule,
     )
+
+    setup_datamodule(datamodule, model, batch_size, parallel)
+    model = load_checkpoint(model, weight)
+    device_manager = DeviceManager(device)
+    save_json_data(datamodule, model, output, device_manager)
+
+    predictions = run(model, datamodule, device_manager)
+
+    output.parent.mkdir(parents=True, exist_ok=True)
+    if output.exists():
+        backup = output.parent / (output.name + "~")
+        logger.warning(
+            f"Output predictions file `{str(output)}` exists - "
+            f"backing it up to `{str(backup)}`...",
+        )
+        shutil.copy(output, backup)
+
+    with output.open("w") as f:
+        json.dump(predictions, f, indent=2)
+    logger.info(f"Predictions saved to `{str(output)}`")
diff --git a/src/mednet/libs/common/scripts/predict.py b/src/mednet/libs/common/scripts/predict.py
index 11cd6336..2cde0d8d 100644
--- a/src/mednet/libs/common/scripts/predict.py
+++ b/src/mednet/libs/common/scripts/predict.py
@@ -4,10 +4,12 @@
 
 import functools
 import pathlib
+import typing
 
 import click
 from clapper.click import ResourceOption
 from clapper.logging import setup
+from mednet.libs.common.models.model import Model
 
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
@@ -120,34 +122,13 @@ def reusable_options(f):
     return wrapper_reusable_options
 
 
-def predict(
-    output,
-    model,
+def setup_datamodule(
     datamodule,
+    model,
     batch_size,
-    device,
-    weight,
     parallel,
-    **_,
 ) -> None:  # numpydoc ignore=PR01
-    """Run inference (generates scores) on all input images, using a pre-trained model."""
-    import json
-    import shutil
-    import typing
-
-    from mednet.libs.classification.engine.predictor import run
-    from mednet.libs.common.engine.device import DeviceManager
-    from mednet.libs.common.utils.checkpointer import (
-        get_checkpoint_to_run_inference,
-    )
-
-    from .utils import (
-        device_properties,
-        execution_metadata,
-        model_summary,
-        save_json_with_backup,
-    )
-
+    """Configure and set up the datamodule."""
     datamodule.set_chunk_size(batch_size, 1)
     datamodule.parallel = parallel
     datamodule.model_transforms = model.model_transforms
@@ -155,15 +136,48 @@ def predict(
     datamodule.prepare_data()
     datamodule.setup(stage="predict")
 
+
+def load_checkpoint(model: Model, weight: pathlib.Path) -> Model:
+    """Load a model checkpoint for prediction.
+
+    Parameters
+    ----------
+    model
+        Instance of a model.
+    weight
+        The base directory containing either the "best", "last" or "periodic"
+        checkpoint to start the training session from.
+
+    Returns
+    -------
+        An instance of the model loaded from the checkpoint.
+    """
+    from mednet.libs.common.utils.checkpointer import (
+        get_checkpoint_to_run_inference,
+    )
+
     if weight.is_dir():
         weight = get_checkpoint_to_run_inference(weight)
 
     logger.info(f"Loading checkpoint from `{weight}`...")
-    model = type(model).load_from_checkpoint(weight, strict=False)
+    return type(model).load_from_checkpoint(weight, strict=False)
 
-    device_manager = DeviceManager(device)
 
-    # register metadata
+def save_json_data(
+    datamodule,
+    model,
+    output,
+    device_manager,
+) -> None:  # numpydoc ignore=PR01
+    """Save prediction hyperparameters into a .json file."""
+
+    from .utils import (
+        device_properties,
+        execution_metadata,
+        model_summary,
+        save_json_with_backup,
+    )
+
     json_data: dict[str, typing.Any] = execution_metadata()
     json_data.update(device_properties(device_manager.device_type))
     json_data.update(
@@ -176,18 +190,3 @@ def predict(
     json_data.update(model_summary(model))
     json_data = {k.replace("_", "-"): v for k, v in json_data.items()}
     save_json_with_backup(output.with_suffix(".meta.json"), json_data)
-
-    predictions = run(model, datamodule, device_manager)
-
-    output.parent.mkdir(parents=True, exist_ok=True)
-    if output.exists():
-        backup = output.parent / (output.name + "~")
-        logger.warning(
-            f"Output predictions file `{str(output)}` exists - "
-            f"backing it up to `{str(backup)}`...",
-        )
-        shutil.copy(output, backup)
-
-    with output.open("w") as f:
-        json.dump(predictions, f, indent=2)
-    logger.info(f"Predictions saved to `{str(output)}`")
diff --git a/src/mednet/libs/segmentation/models/lwnet.py b/src/mednet/libs/segmentation/models/lwnet.py
index c3190f0b..593c2f96 100644
--- a/src/mednet/libs/segmentation/models/lwnet.py
+++ b/src/mednet/libs/segmentation/models/lwnet.py
@@ -319,8 +319,8 @@ class LittleWNet(Model):
         return self._validation_loss(outputs, ground_truths, masks)
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0):
-        outputs = self(batch[0])
-        probabilities = torch.sigmoid(outputs)
+        output = self(batch[0])[1]
+        probabilities = torch.sigmoid(output)
         return separate((probabilities, batch[1]))
 
     def configure_optimizers(self):
diff --git a/src/mednet/libs/segmentation/models/separate.py b/src/mednet/libs/segmentation/models/separate.py
index 8abd2329..7f520a4a 100644
--- a/src/mednet/libs/segmentation/models/separate.py
+++ b/src/mednet/libs/segmentation/models/separate.py
@@ -27,7 +27,7 @@ def _as_predictions(
     list[BinaryPrediction | MultiClassPrediction]
         A list of typed predictions that can be saved to disk.
     """
-    return [(v[1]["name"], v[1]["label"].item(), v[0].item()) for v in samples]
+    return [(v[1]["name"], v[1]["target"], v[0].item()) for v in samples]
 
 
 def separate(batch: Sample) -> list[BinaryPrediction | MultiClassPrediction]:
diff --git a/src/mednet/libs/segmentation/scripts/cli.py b/src/mednet/libs/segmentation/scripts/cli.py
index 52ce017f..49ba9821 100644
--- a/src/mednet/libs/segmentation/scripts/cli.py
+++ b/src/mednet/libs/segmentation/scripts/cli.py
@@ -12,10 +12,10 @@ from . import (
     # compare,
     config,
     database,
+    predict,
     # evaluate,
     # experiment,
     # mkmask,
-    # predict,
     # significance,
     train,
 )
@@ -37,12 +37,15 @@ segmentation.add_command(database.database)
 # segmentation.add_command(evaluate.evaluate)
 # segmentation.add_command(experiment.experiment)
 # segmentation.add_command(mkmask.mkmask)
-# segmentation.add_command(predict.predict)
 # segmentation.add_command(significance.significance)
 segmentation.add_command(train.train)
+segmentation.add_command(predict.predict)
 segmentation.add_command(
     importlib.import_module(
         "mednet.libs.common.scripts.train_analysis",
         package=__name__,
     ).train_analysis,
 )
+segmentation.add_command(
+    importlib.import_module("..experiment", package=__name__).experiment,
+)
diff --git a/src/mednet/libs/segmentation/scripts/experiment.py b/src/mednet/libs/segmentation/scripts/experiment.py
new file mode 100644
index 00000000..0760adde
--- /dev/null
+++ b/src/mednet/libs/segmentation/scripts/experiment.py
@@ -0,0 +1,155 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+from datetime import datetime
+
+import click
+from clapper.click import ConfigCommand, ResourceOption, verbosity_option
+from clapper.logging import setup
+
+from .train import reusable_options as training_options
+
+# avoids X11/graphical desktop requirement when creating plots
+__import__("matplotlib").use("agg")
+
+logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
+
+
+@click.command(
+    entry_point_group="mednet.libs.segmentation.config",
+    cls=ConfigCommand,
+    epilog="""Examples:
+
+\b
+  1. Train a pasa model with montgomery dataset, on the CPU, for only two
+     epochs, then runs inference and evaluation on stock datasets, report
+     performance as a table and figures:
+
+     .. code:: sh
+
+        $ mednet experiment -vv pasa montgomery --epochs=2
+""",
+)
+@training_options
+@verbosity_option(logger=logger, cls=ResourceOption)
+@click.pass_context
+def experiment(
+    ctx,
+    model,
+    output_folder,
+    epochs,
+    batch_size,
+    batch_chunk_count,
+    drop_incomplete_batch,
+    datamodule,
+    validation_period,
+    device,
+    cache_samples,
+    seed,
+    parallel,
+    monitoring_interval,
+    **_,
+):  # numpydoc ignore=PR01
+    r"""Run a complete experiment, from training, to prediction and evaluation.
+
+    This script is just a wrapper around the individual scripts for training,
+    running prediction, and evaluating.  It organises the output in a preset way::
+
+        \b
+       └─ <output-folder>/
+          ├── command.sh
+          ├── model/  # the generated model will be here
+          ├── predictions.json  # the prediction outputs for the sets
+          └── evaluation/  # the outputs of the evaluations for the sets
+    """
+
+    experiment_start_timestamp = datetime.now()
+
+    train_start_timestamp = datetime.now()
+    logger.info(f"Started training at {train_start_timestamp}")
+
+    from .train import train
+
+    train_output_folder = output_folder / "model"
+    ctx.invoke(
+        train,
+        model=model,
+        output_folder=train_output_folder,
+        epochs=epochs,
+        batch_size=batch_size,
+        batch_chunk_count=batch_chunk_count,
+        drop_incomplete_batch=drop_incomplete_batch,
+        datamodule=datamodule,
+        validation_period=validation_period,
+        device=device,
+        cache_samples=cache_samples,
+        seed=seed,
+        parallel=parallel,
+        monitoring_interval=monitoring_interval,
+    )
+    train_stop_timestamp = datetime.now()
+
+    logger.info(f"Ended training in {train_stop_timestamp}")
+    logger.info(
+        f"Training runtime: {train_stop_timestamp-train_start_timestamp}"
+    )
+
+    logger.info("Started train analysis")
+    from mednet.libs.common.scripts.train_analysis import train_analysis
+
+    logdir = train_output_folder / "logs"
+    ctx.invoke(
+        train_analysis,
+        logdir=logdir,
+        output_folder=train_output_folder,
+    )
+
+    logger.info("Ended train analysis")
+
+    predict_start_timestamp = datetime.now()
+    logger.info(f"Started prediction at {predict_start_timestamp}")
+
+    from .predict import predict
+
+    predictions_output = output_folder / "predictions.json"
+
+    ctx.invoke(
+        predict,
+        output=predictions_output,
+        model=model,
+        datamodule=datamodule,
+        device=device,
+        weight=train_output_folder,
+        batch_size=batch_size,
+        parallel=parallel,
+    )
+
+    predict_stop_timestamp = datetime.now()
+    logger.info(f"Ended prediction in {predict_stop_timestamp}")
+    logger.info(
+        f"Prediction runtime: {predict_stop_timestamp-predict_start_timestamp}"
+    )
+
+    """evaluation_start_timestamp = datetime.now()
+    logger.info(f"Started evaluation at {evaluation_start_timestamp}")
+
+    from .evaluate import evaluate
+
+    ctx.invoke(
+        evaluate,
+        predictions=predictions_output,
+        output_folder=output_folder,
+        threshold="validation",
+    )
+
+    evaluation_stop_timestamp = datetime.now()
+    logger.info(f"Ended prediction in {evaluation_stop_timestamp}")
+    logger.info(
+        f"Prediction runtime: {evaluation_stop_timestamp-evaluation_start_timestamp}"
+    )"""
+
+    experiment_stop_timestamp = datetime.now()
+    logger.info(
+        f"Total experiment runtime: {experiment_stop_timestamp-experiment_start_timestamp}"
+    )
diff --git a/src/mednet/libs/segmentation/scripts/predict.py b/src/mednet/libs/segmentation/scripts/predict.py
new file mode 100644
index 00000000..9e0b44b0
--- /dev/null
+++ b/src/mednet/libs/segmentation/scripts/predict.py
@@ -0,0 +1,86 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+
+import click
+from clapper.click import ResourceOption, verbosity_option
+from clapper.logging import setup
+from mednet.libs.common.scripts.click import ConfigCommand
+from mednet.libs.common.scripts.predict import reusable_options
+
+logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
+
+
+@click.command(
+    entry_point_group="mednet.libs.segmentation.config",
+    cls=ConfigCommand,
+    epilog="""Examples:
+
+1. Run prediction on an existing DataModule configuration:
+
+   .. code:: sh
+
+      mednet predict -vv pasa montgomery --weight=path/to/model.ckpt --output=path/to/predictions.json
+
+2. Enable multi-processing data loading with 6 processes:
+
+   .. code:: sh
+
+      mednet predict -vv pasa montgomery --parallel=6 --weight=path/to/model.ckpt --output=path/to/predictions.json
+
+""",
+)
+@reusable_options
+@verbosity_option(logger=logger, cls=ResourceOption, expose_value=False)
+def predict(
+    output,
+    model,
+    datamodule,
+    batch_size,
+    device,
+    weight,
+    parallel,
+    **_,
+) -> None:  # numpydoc ignore=PR01
+    """Run inference (generates scores) on all input images, using a pre-trained model."""
+
+    import json
+    import shutil
+
+    from mednet.libs.classification.engine.predictor import run
+    from mednet.libs.common.engine.device import DeviceManager
+    from mednet.libs.common.scripts.predict import (
+        load_checkpoint,
+        save_json_data,
+        setup_datamodule,
+    )
+
+    setup_datamodule(datamodule, model, batch_size, parallel)
+    model = load_checkpoint(model, weight)
+    device_manager = DeviceManager(device)
+    save_json_data(datamodule, model, output, device_manager)
+
+    predictions = run(model, datamodule, device_manager)
+
+    output.parent.mkdir(parents=True, exist_ok=True)
+    if output.exists():
+        backup = output.parent / (output.name + "~")
+        logger.warning(
+            f"Output predictions file `{str(output)}` exists - "
+            f"backing it up to `{str(backup)}`...",
+        )
+        shutil.copy(output, backup)
+
+    # Remove targets from predictions, as they are images and not serializable
+    # A better solution should be found
+    serializable_predictions = {}
+    for split_name, sample in predictions.items():
+        split_predictions = []
+        for s in sample:
+            split_predictions.append((s[0], s[2]))
+        serializable_predictions[split_name] = split_predictions
+
+    with output.open("w") as f:
+        json.dump(serializable_predictions, f, indent=2)
+    logger.info(f"Predictions saved to `{str(output)}`")
-- 
GitLab