From 9a399fad93c0bee944bcc962c2f60cf64216eadc Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Wed, 19 Jun 2024 17:16:44 +0200
Subject: [PATCH] [segmentation.view] Add view script

---
 src/mednet/libs/segmentation/scripts/cli.py  |   2 +
 src/mednet/libs/segmentation/scripts/view.py | 239 +++++++++++++++++++
 2 files changed, 241 insertions(+)
 create mode 100644 src/mednet/libs/segmentation/scripts/view.py

diff --git a/src/mednet/libs/segmentation/scripts/cli.py b/src/mednet/libs/segmentation/scripts/cli.py
index a4a7730a..2331fd13 100644
--- a/src/mednet/libs/segmentation/scripts/cli.py
+++ b/src/mednet/libs/segmentation/scripts/cli.py
@@ -17,6 +17,7 @@ from . import (
     # mkmask,
     # significance,
     train,
+    view,
 )
 
 
@@ -44,6 +45,7 @@ segmentation.add_command(
         package=__name__,
     ).train_analysis,
 )
+segmentation.add_command(view.view)
 segmentation.add_command(
     importlib.import_module("..experiment", package=__name__).experiment,
 )
diff --git a/src/mednet/libs/segmentation/scripts/view.py b/src/mednet/libs/segmentation/scripts/view.py
new file mode 100644
index 00000000..30c3b355
--- /dev/null
+++ b/src/mednet/libs/segmentation/scripts/view.py
@@ -0,0 +1,239 @@
+# SPDX-FileCopyrightText: Copyright © 2024 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+import pathlib
+
+import click
+import h5py
+import PIL.Image
+import torch
+from clapper.click import ResourceOption
+from mednet.libs.common.scripts.click import ConfigCommand
+from PIL import ImageColor
+from PIL.ImageChops import invert, logical_and
+from torchvision.transforms.functional import to_pil_image
+
+
+def get_tp_mask(binary_image: PIL.Image, binary_target: PIL.Image) -> PIL.Image:
+    """Compute the true positive mask.
+
+    Parameters
+    ----------
+    binary_image
+        B/W image to compare to the target.
+    binary_target
+        B/W reference image.
+
+    Returns
+    -------
+        Image with white pixels where both image and target are white, black pixels otherwise.
+    """
+    return logical_and(binary_image, binary_target)
+
+
+def get_tn_mask(binary_image: PIL.Image, binary_target: PIL.Image) -> PIL.Image:
+    """Compute the false positive mask.
+
+    Parameters
+    ----------
+    binary_image
+        B/W image to compare to the target.
+    binary_target
+        B/W reference image.
+
+    Returns
+    -------
+        Image with white pixels where both image and target are black, black pixels otherwise.
+    """
+    return logical_and(invert(binary_image), invert(binary_target))
+
+
+def get_fp_mask(binary_image: PIL.Image, binary_target: PIL.Image) -> PIL.Image:
+    """Compute the true positive mask.
+
+    Parameters
+    ----------
+    binary_image
+        B/W image to compare to the target.
+    binary_target
+        B/W reference image.
+
+    Returns
+    -------
+        Image with white pixels where image is white and target is black, black pixels otherwise.
+    """
+    return logical_and(binary_image, invert(binary_target))
+
+
+def get_fn_mask(binary_image: PIL.Image, binary_target: PIL.Image) -> PIL.Image:
+    """Compute the true positive mask.
+
+    Parameters
+    ----------
+    binary_image
+        B/W image to compare to the target.
+    binary_target
+        B/W reference image.
+
+    Returns
+    -------
+        Image with white pixels where image is black and target is white, black pixels otherwise.
+    """
+    return logical_and(invert(binary_image), binary_target)
+
+
+def get_masks(
+    binary_prediction_image: PIL.Image, binary_target_image: PIL.Image
+) -> tuple[PIL.Image.Image, PIL.Image, PIL.Image, PIL.Image]:
+    """Given a B/W binary image and a target, return the tp, tn, fp, fn masks.
+
+    Parameters
+    ----------
+    binary_prediction_image
+        B/W image.
+    binary_target_image
+        B/W reference image.
+
+    Returns
+    -------
+        The tp, tn, fp, fn masks
+    """
+    tp_mask = get_tp_mask(binary_prediction_image, binary_target_image)
+    tn_mask = get_tn_mask(binary_prediction_image, binary_target_image)
+    fp_mask = get_fp_mask(binary_prediction_image, binary_target_image)
+    fn_mask = get_fn_mask(binary_prediction_image, binary_target_image)
+
+    return tp_mask, tn_mask, fp_mask, fn_mask
+
+
+def load_image_from_hdf5(filepath: pathlib.Path, key: str) -> PIL.Image:
+    """Load an image located in an hdf5 file by key.
+
+    Parameters
+    ----------
+    filepath
+        Path to an hdf5 file.
+    key
+        Key to search for in the hdf5 file.
+
+    Returns
+    -------
+        The loaded PIL.Image.
+    """
+    with h5py.File(filepath, "r") as f:
+        img = to_pil_image(torch.from_numpy(f.get(key)[:]))
+
+    return img  # noqa: RET504
+
+
+def image_to_binary(image: PIL.Image, threshold: int = 127) -> PIL.Image:
+    """Change the mode of a PIL image to '1' (binary) using a threshold.
+
+    Parameters
+    ----------
+    image
+        The image to convert to binary mode.
+    threshold
+        The threshold to use for convertion, with p <= threshold = black, p > threshold = white.
+
+    Returns
+    -------
+        The binary image.
+    """
+    image = image.point(lambda p: 255 if p > threshold else 0)
+    return image.convert("1")
+
+
+def color_with_mask(image: PIL.Image, mask: PIL.Image, color: ImageColor) -> PIL.Image:
+    """Colorize the image with a given color by using a mask.
+
+    Parameters
+    ----------
+    image
+        The image to colorize.
+    mask
+        Mask used to indicate where to apply the color.
+    color
+        The color to apply.
+
+    Returns
+    -------
+        The colorized image.
+    """
+    color_plane = PIL.Image.new(mode="RGB", size=image.size, color=color)
+
+    image = image.convert("RGB")
+    image.paste(color_plane, mask)
+    return image
+
+
+@click.command(
+    entry_point_group="mednet.libs.segmentation.config",
+    cls=ConfigCommand,
+    epilog="""Examples:
+
+\b
+  1. Load images from an hdf5 file and saves a new image with the tp, tn, fp, fn colorized:
+
+     .. code:: sh
+
+        $ mednet segmentation view -f results/predictions/test/test.hdf5 -o colorized_prediction.png
+
+""",
+)
+@click.option(
+    "--hdf5-file",
+    "-f",
+    help="File in which predictions are currently stored",
+    required=True,
+    type=click.Path(
+        file_okay=True,
+        dir_okay=False,
+        writable=True,
+        path_type=pathlib.Path,
+    ),
+    cls=ResourceOption,
+)
+@click.option(
+    "--output-file",
+    "-o",
+    help="File in which to store the result (created if does not exist)",
+    required=True,
+    type=click.Path(
+        file_okay=True,
+        dir_okay=False,
+        writable=True,
+        path_type=pathlib.Path,
+    ),
+    default="segmentation.png",
+    cls=ResourceOption,
+)
+def view(
+    hdf5_file: pathlib.Path,
+    output_file: pathlib.Path,
+    **_,  # ignored
+):  # numpydoc ignore=PR01
+    """Load images from an hdf5 file and saves a new image with the tp, tn, fp, fn colorized."""
+    colors_dict = {
+        "tp": ImageColor.getcolor("white", "RGB"),
+        "tn": ImageColor.getcolor("black", "RGB"),
+        "fp": ImageColor.getcolor("green", "RGB"),
+        "fn": ImageColor.getcolor("red", "RGB"),
+    }
+
+    pred_img = load_image_from_hdf5(hdf5_file, "img")
+    binary_pred_img = image_to_binary(pred_img, 127)
+
+    target_img = load_image_from_hdf5(hdf5_file, "target")
+    binary_target_img = image_to_binary(target_img, 127)
+
+    tp_mask, tn_mask, fp_mask, fn_mask = get_masks(binary_pred_img, binary_target_img)
+
+    colorized_image = pred_img
+    colorized_image = color_with_mask(colorized_image, tp_mask, colors_dict["tp"])
+    colorized_image = color_with_mask(colorized_image, tn_mask, colors_dict["tn"])
+    colorized_image = color_with_mask(colorized_image, fp_mask, colors_dict["fp"])
+    colorized_image = color_with_mask(colorized_image, fn_mask, colors_dict["fn"])
+
+    colorized_image.save(output_file)
-- 
GitLab