From f869f81eeba1faa55cdfe998b10cb3e965cdcbb9 Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Tue, 25 Jun 2024 15:52:46 +0200
Subject: [PATCH] [libs.segmentation.scripts.view] Implement new view script
 using the output of the predict step; Implement colorization/transparency
 options

---
 src/mednet/libs/segmentation/engine/viewer.py |  93 +++++
 src/mednet/libs/segmentation/scripts/cli.py   |   6 -
 src/mednet/libs/segmentation/scripts/view.py  | 364 ++++++++----------
 3 files changed, 259 insertions(+), 204 deletions(-)
 create mode 100644 src/mednet/libs/segmentation/engine/viewer.py

diff --git a/src/mednet/libs/segmentation/engine/viewer.py b/src/mednet/libs/segmentation/engine/viewer.py
new file mode 100644
index 00000000..6699f683
--- /dev/null
+++ b/src/mednet/libs/segmentation/engine/viewer.py
@@ -0,0 +1,93 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+import pathlib
+
+import h5py
+import numpy
+import numpy.typing
+import PIL.Image
+import PIL.ImageOps
+import torch
+import torchvision.transforms.functional
+
+from .evaluator import tfpn_masks
+
+
+def view(
+    basedir: pathlib.Path,
+    stem: str,
+    threshold: float,
+    show_errors: bool,
+    tp_color: tuple[int, int, int],
+    fp_color: tuple[int, int, int],
+    fn_color: tuple[int, int, int],
+    alpha: float,
+) -> PIL.Image.Image:
+    """Create an segmentation map visualisation.
+
+    Parameters
+    ----------
+    basedir
+        Base directory where the prediction indicated by ``stem`` is stored.
+    stem
+        Name of the HDF5 file containing the predictions, as output by the
+        ``predict`` CLI.
+    threshold
+        The threshold to apply to the probability map loaded from the HDF5
+        file.
+    show_errors
+        If set to ``True``, then colours false-positives (in red), and false
+        negatives (in green).
+    tp_color
+        Tuple that indicates which color to use for displaying true-positives.
+    fp_color
+        Tuple that indicates which color to use for displaying false-positives.
+    fn_color
+        Tuple that indicates which color to use for displaying false-negatives.
+    alpha
+        How transparent will the overlay be.
+
+    Returns
+    -------
+        An image with an overlayed segmentation map that can be saved or
+        displayed.
+    """
+
+    def _to_pil(arr: numpy.typing.NDArray[numpy.float32]) -> PIL.Image.Image:
+        return torchvision.transforms.functional.to_pil_image(torch.Tensor(arr))
+
+    with h5py.File(basedir / stem, "r") as f:
+        image: numpy.typing.NDArray[numpy.float32] = numpy.array(f.get("image"))
+        pred: numpy.typing.NDArray[numpy.float32] = numpy.array(f.get("prediction"))
+        target: numpy.typing.NDArray[numpy.bool_] = numpy.array(f.get("target"))
+        mask: numpy.typing.NDArray[numpy.bool_] = numpy.array(f.get("mask"))
+
+    image *= mask
+    pred *= mask
+    target = numpy.logical_and(target, mask)
+
+    if show_errors:
+        tp, fp, _, fn = tfpn_masks(pred, target, threshold)
+
+        # change to PIL representation
+        tp_pil = _to_pil(tp.astype(numpy.float32))
+        tp_pil_colored = PIL.ImageOps.colorize(tp_pil, (0, 0, 0), tp_color)
+
+        fp_pil = _to_pil(fp.astype(numpy.float32))
+        fp_pil_colored = PIL.ImageOps.colorize(fp_pil, (0, 0, 0), fp_color)
+
+        fn_pil = _to_pil(fn.astype(numpy.float32))
+        fn_pil_colored = PIL.ImageOps.colorize(fn_pil, (0, 0, 0), fn_color)
+
+        tp_pil_colored.paste(fp_pil_colored, mask=fp_pil)
+        tp_pil_colored.paste(fn_pil_colored, mask=fn_pil)
+
+    else:
+        overlay = pred >= threshold
+        tp_pil = _to_pil(overlay.astype(numpy.float32))
+        tp_pil_colored = PIL.ImageOps.colorize(tp_pil, (0, 0, 0), tp_color)
+
+    retval = _to_pil(image)
+    return PIL.Image.blend(retval, tp_pil_colored, alpha)
diff --git a/src/mednet/libs/segmentation/scripts/cli.py b/src/mednet/libs/segmentation/scripts/cli.py
index 2331fd13..f2841daf 100644
--- a/src/mednet/libs/segmentation/scripts/cli.py
+++ b/src/mednet/libs/segmentation/scripts/cli.py
@@ -9,13 +9,10 @@ from clapper.click import AliasedGroup
 
 from . import (
     # analyze,
-    # compare,
     config,
     database,
     evaluate,
     predict,
-    # mkmask,
-    # significance,
     train,
     view,
 )
@@ -31,11 +28,8 @@ def segmentation():
 
 
 # segmentation.add_command(analyze.analyze)
-# segmentation.add_command(compare.compare)
 segmentation.add_command(config.config)
 segmentation.add_command(database.database)
-# segmentation.add_command(mkmask.mkmask)
-# segmentation.add_command(significance.significance)
 segmentation.add_command(train.train)
 segmentation.add_command(predict.predict)
 segmentation.add_command(evaluate.evaluate)
diff --git a/src/mednet/libs/segmentation/scripts/view.py b/src/mednet/libs/segmentation/scripts/view.py
index 30c3b355..e0cb9706 100644
--- a/src/mednet/libs/segmentation/scripts/view.py
+++ b/src/mednet/libs/segmentation/scripts/view.py
@@ -1,171 +1,21 @@
-# SPDX-FileCopyrightText: Copyright © 2024 Idiap Research Institute <contact@idiap.ch>
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
+import json
 import pathlib
+import typing
 
 import click
-import h5py
-import PIL.Image
-import torch
-from clapper.click import ResourceOption
+import tqdm
+from clapper.click import ResourceOption, verbosity_option
+from clapper.logging import setup
 from mednet.libs.common.scripts.click import ConfigCommand
-from PIL import ImageColor
-from PIL.ImageChops import invert, logical_and
-from torchvision.transforms.functional import to_pil_image
-
-
-def get_tp_mask(binary_image: PIL.Image, binary_target: PIL.Image) -> PIL.Image:
-    """Compute the true positive mask.
-
-    Parameters
-    ----------
-    binary_image
-        B/W image to compare to the target.
-    binary_target
-        B/W reference image.
-
-    Returns
-    -------
-        Image with white pixels where both image and target are white, black pixels otherwise.
-    """
-    return logical_and(binary_image, binary_target)
-
-
-def get_tn_mask(binary_image: PIL.Image, binary_target: PIL.Image) -> PIL.Image:
-    """Compute the false positive mask.
-
-    Parameters
-    ----------
-    binary_image
-        B/W image to compare to the target.
-    binary_target
-        B/W reference image.
-
-    Returns
-    -------
-        Image with white pixels where both image and target are black, black pixels otherwise.
-    """
-    return logical_and(invert(binary_image), invert(binary_target))
-
-
-def get_fp_mask(binary_image: PIL.Image, binary_target: PIL.Image) -> PIL.Image:
-    """Compute the true positive mask.
-
-    Parameters
-    ----------
-    binary_image
-        B/W image to compare to the target.
-    binary_target
-        B/W reference image.
-
-    Returns
-    -------
-        Image with white pixels where image is white and target is black, black pixels otherwise.
-    """
-    return logical_and(binary_image, invert(binary_target))
-
-
-def get_fn_mask(binary_image: PIL.Image, binary_target: PIL.Image) -> PIL.Image:
-    """Compute the true positive mask.
-
-    Parameters
-    ----------
-    binary_image
-        B/W image to compare to the target.
-    binary_target
-        B/W reference image.
-
-    Returns
-    -------
-        Image with white pixels where image is black and target is white, black pixels otherwise.
-    """
-    return logical_and(invert(binary_image), binary_target)
-
-
-def get_masks(
-    binary_prediction_image: PIL.Image, binary_target_image: PIL.Image
-) -> tuple[PIL.Image.Image, PIL.Image, PIL.Image, PIL.Image]:
-    """Given a B/W binary image and a target, return the tp, tn, fp, fn masks.
-
-    Parameters
-    ----------
-    binary_prediction_image
-        B/W image.
-    binary_target_image
-        B/W reference image.
-
-    Returns
-    -------
-        The tp, tn, fp, fn masks
-    """
-    tp_mask = get_tp_mask(binary_prediction_image, binary_target_image)
-    tn_mask = get_tn_mask(binary_prediction_image, binary_target_image)
-    fp_mask = get_fp_mask(binary_prediction_image, binary_target_image)
-    fn_mask = get_fn_mask(binary_prediction_image, binary_target_image)
-
-    return tp_mask, tn_mask, fp_mask, fn_mask
-
-
-def load_image_from_hdf5(filepath: pathlib.Path, key: str) -> PIL.Image:
-    """Load an image located in an hdf5 file by key.
-
-    Parameters
-    ----------
-    filepath
-        Path to an hdf5 file.
-    key
-        Key to search for in the hdf5 file.
-
-    Returns
-    -------
-        The loaded PIL.Image.
-    """
-    with h5py.File(filepath, "r") as f:
-        img = to_pil_image(torch.from_numpy(f.get(key)[:]))
-
-    return img  # noqa: RET504
-
-
-def image_to_binary(image: PIL.Image, threshold: int = 127) -> PIL.Image:
-    """Change the mode of a PIL image to '1' (binary) using a threshold.
-
-    Parameters
-    ----------
-    image
-        The image to convert to binary mode.
-    threshold
-        The threshold to use for convertion, with p <= threshold = black, p > threshold = white.
-
-    Returns
-    -------
-        The binary image.
-    """
-    image = image.point(lambda p: 255 if p > threshold else 0)
-    return image.convert("1")
-
-
-def color_with_mask(image: PIL.Image, mask: PIL.Image, color: ImageColor) -> PIL.Image:
-    """Colorize the image with a given color by using a mask.
-
-    Parameters
-    ----------
-    image
-        The image to colorize.
-    mask
-        Mask used to indicate where to apply the color.
-    color
-        The color to apply.
-
-    Returns
-    -------
-        The colorized image.
-    """
-    color_plane = PIL.Image.new(mode="RGB", size=image.size, color=color)
-
-    image = image.convert("RGB")
-    image.paste(color_plane, mask)
-    return image
+from mednet.libs.segmentation.engine.evaluator import SUPPORTED_METRIC_TYPE
+
+from .evaluate import validate_threshold
+
+logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
 
 @click.command(
@@ -174,66 +24,184 @@ def color_with_mask(image: PIL.Image, mask: PIL.Image, color: ImageColor) -> PIL
     epilog="""Examples:
 
 \b
-  1. Load images from an hdf5 file and saves a new image with the tp, tn, fp, fn colorized:
+  1. Runs evaluation on an existing dataset configuration:
 
      .. code:: sh
 
-        $ mednet segmentation view -f results/predictions/test/test.hdf5 -o colorized_prediction.png
-
+        $ mednet segmentation view -vv --predictions=path/to/predictions.json --output-folder=path/to/results
 """,
 )
 @click.option(
-    "--hdf5-file",
-    "-f",
-    help="File in which predictions are currently stored",
+    "--predictions",
+    "-p",
+    help="""Path to the JSON file describing available predictions. The actual
+    predictions are supposed to lie on the same folder.""",
     required=True,
     type=click.Path(
         file_okay=True,
         dir_okay=False,
-        writable=True,
+        writable=False,
         path_type=pathlib.Path,
     ),
     cls=ResourceOption,
 )
 @click.option(
-    "--output-file",
+    "--output-folder",
     "-o",
-    help="File in which to store the result (created if does not exist)",
+    help="Directory in which to store results (created if does not exist)",
     required=True,
     type=click.Path(
-        file_okay=True,
-        dir_okay=False,
+        file_okay=False,
+        dir_okay=True,
         writable=True,
         path_type=pathlib.Path,
     ),
-    default="segmentation.png",
+    default="results",
+    cls=ResourceOption,
+)
+@click.option(
+    "--threshold",
+    "-t",
+    help="""This number is used to define positives and negatives from
+    probability maps, and used to report metrics based on a threshold chosen *a
+    priori*. It can be set to a floating-point value, or to the name of dataset
+    split in ``--predictions``.
+    """,
+    default="0.5",
+    show_default=True,
+    required=False,
+    cls=ResourceOption,
+)
+@click.option(
+    "--metric",
+    "-m",
+    help="""If threshold is set to the name of a split in ``--predictions``,
+    then this parameter defines the metric function to be used to evaluate the
+    threshold at which the metric reaches its maximum value. All other splits
+    are evaluated with respect to this threshold.""",
+    default="f1",
+    type=click.Choice(typing.get_args(SUPPORTED_METRIC_TYPE), case_sensitive=True),
+    show_default=True,
+    required=True,
+    cls=ResourceOption,
+)
+@click.option(
+    "--steps",
+    "-s",
+    help="""Number of steps for evaluating metrics on various splits. This
+    value is used when drawing precision-recall plots, or when deciding the
+    highest metric value on splits.""",
+    default=100,
+    type=click.IntRange(10),
+    show_default=True,
+    required=True,
+    cls=ResourceOption,
+)
+@click.option(
+    "--show-errors/--no-show-errors",
+    "-e/-E",
+    help="""If set, then shows a colorized version of the segmentation map in
+    which false-positives are marked in red, and false-negatives in green.
+    True positives are always marked in white.""",
+    default=False,
+    show_default=True,
+    required=True,
     cls=ResourceOption,
 )
+@click.option(
+    "--alpha",
+    "-a",
+    help="""Defines the transparency weighting between the original image and
+    the predicted segmentation maps. A value of 1.0 makes the program output
+    only segmentation maps.  A value of 0.0 makes the program output only the
+    processed image.""",
+    default=0.6,
+    type=click.FloatRange(0.0, 1.0),
+    show_default=True,
+    required=True,
+    cls=ResourceOption,
+)
+@verbosity_option(logger=logger, cls=ResourceOption, expose_value=False)
 def view(
-    hdf5_file: pathlib.Path,
-    output_file: pathlib.Path,
+    predictions: pathlib.Path,
+    output_folder: pathlib.Path,
+    threshold: str | float,
+    metric: str,
+    steps: int,
+    show_errors: bool,
+    alpha: float,
     **_,  # ignored
 ):  # numpydoc ignore=PR01
-    """Load images from an hdf5 file and saves a new image with the tp, tn, fp, fn colorized."""
-    colors_dict = {
-        "tp": ImageColor.getcolor("white", "RGB"),
-        "tn": ImageColor.getcolor("black", "RGB"),
-        "fp": ImageColor.getcolor("green", "RGB"),
-        "fn": ImageColor.getcolor("red", "RGB"),
-    }
-
-    pred_img = load_image_from_hdf5(hdf5_file, "img")
-    binary_pred_img = image_to_binary(pred_img, 127)
-
-    target_img = load_image_from_hdf5(hdf5_file, "target")
-    binary_target_img = image_to_binary(target_img, 127)
-
-    tp_mask, tn_mask, fp_mask, fn_mask = get_masks(binary_pred_img, binary_target_img)
-
-    colorized_image = pred_img
-    colorized_image = color_with_mask(colorized_image, tp_mask, colors_dict["tp"])
-    colorized_image = color_with_mask(colorized_image, tn_mask, colors_dict["tn"])
-    colorized_image = color_with_mask(colorized_image, fp_mask, colors_dict["fp"])
-    colorized_image = color_with_mask(colorized_image, fn_mask, colors_dict["fn"])
-
-    colorized_image.save(output_file)
+    """Evaluate predictions (from a model) on a segmentation task."""
+
+    import numpy
+    from mednet.libs.common.scripts.utils import (
+        execution_metadata,
+        save_json_with_backup,
+    )
+    from mednet.libs.segmentation.engine.evaluator import (
+        compute_metric,
+        load_count,
+        name2metric,
+    )
+    from mednet.libs.segmentation.engine.viewer import view
+
+    evaluation_filename = "evaluation.json"
+    evaluation_file = output_folder / evaluation_filename
+
+    with predictions.open("r") as f:
+        predict_data = json.load(f)
+
+    # register metadata
+    json_data: dict[str, typing.Any] = execution_metadata()
+    json_data.update(
+        dict(
+            predictions=str(predictions),
+            output_folder=str(output_folder),
+            threshold=threshold,
+            metric=metric,
+            steps=steps,
+        ),
+    )
+    json_data = {k.replace("_", "-"): v for k, v in json_data.items()}
+    save_json_with_backup(evaluation_file.with_suffix(".meta.json"), json_data)
+
+    threshold = validate_threshold(threshold, predict_data)
+    threshold_list = numpy.arange(
+        0.0, (1.0 + 1 / steps), 1 / steps, dtype=numpy.float64
+    )
+
+    if isinstance(threshold, str):
+        # Compute threshold on specified split, if required
+        logger.info(f"Evaluating threshold on `{threshold}` split using " f"`{metric}`")
+        counts = load_count(predictions.parent, predict_data[threshold], threshold_list)
+        metric_list = compute_metric(
+            counts, name2metric(typing.cast(SUPPORTED_METRIC_TYPE, metric))
+        )
+        threshold_index = metric_list.argmax()
+        logger.info(f"Set --threshold={threshold_list[threshold_index]:.4f}")
+
+    else:
+        # must figure out the closest threshold from the list we are using
+        threshold_index = (numpy.abs(threshold_list - threshold)).argmin()
+        logger.info(f"Set --threshold={threshold_list[threshold_index]:.4f}")
+
+    # create visualisations
+    for split_name, sample_list in predict_data.items():
+        logger.info(
+            f"Creating {len(sample_list)} visualisations for split `{split_name}`"
+        )
+        for sample in tqdm.tqdm(sample_list):
+            image = view(
+                predictions.parent,
+                sample[1],
+                threshold=threshold_list[threshold_index],
+                show_errors=show_errors,
+                tp_color=(255, 255, 255),
+                fp_color=(255, 0, 0),
+                fn_color=(0, 255, 0),
+                alpha=alpha,
+            )
+            dest = (output_folder / sample[1]).with_suffix(".png")
+            tqdm.tqdm.write(f"{sample[1]} -> {dest}")
+            image.save(dest)
-- 
GitLab