From 9a399fad93c0bee944bcc962c2f60cf64216eadc Mon Sep 17 00:00:00 2001 From: dcarron <daniel.carron@idiap.ch> Date: Wed, 19 Jun 2024 17:16:44 +0200 Subject: [PATCH] [segmentation.view] Add view script --- src/mednet/libs/segmentation/scripts/cli.py | 2 + src/mednet/libs/segmentation/scripts/view.py | 239 +++++++++++++++++++ 2 files changed, 241 insertions(+) create mode 100644 src/mednet/libs/segmentation/scripts/view.py diff --git a/src/mednet/libs/segmentation/scripts/cli.py b/src/mednet/libs/segmentation/scripts/cli.py index a4a7730a..2331fd13 100644 --- a/src/mednet/libs/segmentation/scripts/cli.py +++ b/src/mednet/libs/segmentation/scripts/cli.py @@ -17,6 +17,7 @@ from . import ( # mkmask, # significance, train, + view, ) @@ -44,6 +45,7 @@ segmentation.add_command( package=__name__, ).train_analysis, ) +segmentation.add_command(view.view) segmentation.add_command( importlib.import_module("..experiment", package=__name__).experiment, ) diff --git a/src/mednet/libs/segmentation/scripts/view.py b/src/mednet/libs/segmentation/scripts/view.py new file mode 100644 index 00000000..30c3b355 --- /dev/null +++ b/src/mednet/libs/segmentation/scripts/view.py @@ -0,0 +1,239 @@ +# SPDX-FileCopyrightText: Copyright © 2024 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import pathlib + +import click +import h5py +import PIL.Image +import torch +from clapper.click import ResourceOption +from mednet.libs.common.scripts.click import ConfigCommand +from PIL import ImageColor +from PIL.ImageChops import invert, logical_and +from torchvision.transforms.functional import to_pil_image + + +def get_tp_mask(binary_image: PIL.Image, binary_target: PIL.Image) -> PIL.Image: + """Compute the true positive mask. + + Parameters + ---------- + binary_image + B/W image to compare to the target. + binary_target + B/W reference image. + + Returns + ------- + Image with white pixels where both image and target are white, black pixels otherwise. + """ + return logical_and(binary_image, binary_target) + + +def get_tn_mask(binary_image: PIL.Image, binary_target: PIL.Image) -> PIL.Image: + """Compute the false positive mask. + + Parameters + ---------- + binary_image + B/W image to compare to the target. + binary_target + B/W reference image. + + Returns + ------- + Image with white pixels where both image and target are black, black pixels otherwise. + """ + return logical_and(invert(binary_image), invert(binary_target)) + + +def get_fp_mask(binary_image: PIL.Image, binary_target: PIL.Image) -> PIL.Image: + """Compute the true positive mask. + + Parameters + ---------- + binary_image + B/W image to compare to the target. + binary_target + B/W reference image. + + Returns + ------- + Image with white pixels where image is white and target is black, black pixels otherwise. + """ + return logical_and(binary_image, invert(binary_target)) + + +def get_fn_mask(binary_image: PIL.Image, binary_target: PIL.Image) -> PIL.Image: + """Compute the true positive mask. + + Parameters + ---------- + binary_image + B/W image to compare to the target. + binary_target + B/W reference image. + + Returns + ------- + Image with white pixels where image is black and target is white, black pixels otherwise. + """ + return logical_and(invert(binary_image), binary_target) + + +def get_masks( + binary_prediction_image: PIL.Image, binary_target_image: PIL.Image +) -> tuple[PIL.Image.Image, PIL.Image, PIL.Image, PIL.Image]: + """Given a B/W binary image and a target, return the tp, tn, fp, fn masks. + + Parameters + ---------- + binary_prediction_image + B/W image. + binary_target_image + B/W reference image. + + Returns + ------- + The tp, tn, fp, fn masks + """ + tp_mask = get_tp_mask(binary_prediction_image, binary_target_image) + tn_mask = get_tn_mask(binary_prediction_image, binary_target_image) + fp_mask = get_fp_mask(binary_prediction_image, binary_target_image) + fn_mask = get_fn_mask(binary_prediction_image, binary_target_image) + + return tp_mask, tn_mask, fp_mask, fn_mask + + +def load_image_from_hdf5(filepath: pathlib.Path, key: str) -> PIL.Image: + """Load an image located in an hdf5 file by key. + + Parameters + ---------- + filepath + Path to an hdf5 file. + key + Key to search for in the hdf5 file. + + Returns + ------- + The loaded PIL.Image. + """ + with h5py.File(filepath, "r") as f: + img = to_pil_image(torch.from_numpy(f.get(key)[:])) + + return img # noqa: RET504 + + +def image_to_binary(image: PIL.Image, threshold: int = 127) -> PIL.Image: + """Change the mode of a PIL image to '1' (binary) using a threshold. + + Parameters + ---------- + image + The image to convert to binary mode. + threshold + The threshold to use for convertion, with p <= threshold = black, p > threshold = white. + + Returns + ------- + The binary image. + """ + image = image.point(lambda p: 255 if p > threshold else 0) + return image.convert("1") + + +def color_with_mask(image: PIL.Image, mask: PIL.Image, color: ImageColor) -> PIL.Image: + """Colorize the image with a given color by using a mask. + + Parameters + ---------- + image + The image to colorize. + mask + Mask used to indicate where to apply the color. + color + The color to apply. + + Returns + ------- + The colorized image. + """ + color_plane = PIL.Image.new(mode="RGB", size=image.size, color=color) + + image = image.convert("RGB") + image.paste(color_plane, mask) + return image + + +@click.command( + entry_point_group="mednet.libs.segmentation.config", + cls=ConfigCommand, + epilog="""Examples: + +\b + 1. Load images from an hdf5 file and saves a new image with the tp, tn, fp, fn colorized: + + .. code:: sh + + $ mednet segmentation view -f results/predictions/test/test.hdf5 -o colorized_prediction.png + +""", +) +@click.option( + "--hdf5-file", + "-f", + help="File in which predictions are currently stored", + required=True, + type=click.Path( + file_okay=True, + dir_okay=False, + writable=True, + path_type=pathlib.Path, + ), + cls=ResourceOption, +) +@click.option( + "--output-file", + "-o", + help="File in which to store the result (created if does not exist)", + required=True, + type=click.Path( + file_okay=True, + dir_okay=False, + writable=True, + path_type=pathlib.Path, + ), + default="segmentation.png", + cls=ResourceOption, +) +def view( + hdf5_file: pathlib.Path, + output_file: pathlib.Path, + **_, # ignored +): # numpydoc ignore=PR01 + """Load images from an hdf5 file and saves a new image with the tp, tn, fp, fn colorized.""" + colors_dict = { + "tp": ImageColor.getcolor("white", "RGB"), + "tn": ImageColor.getcolor("black", "RGB"), + "fp": ImageColor.getcolor("green", "RGB"), + "fn": ImageColor.getcolor("red", "RGB"), + } + + pred_img = load_image_from_hdf5(hdf5_file, "img") + binary_pred_img = image_to_binary(pred_img, 127) + + target_img = load_image_from_hdf5(hdf5_file, "target") + binary_target_img = image_to_binary(target_img, 127) + + tp_mask, tn_mask, fp_mask, fn_mask = get_masks(binary_pred_img, binary_target_img) + + colorized_image = pred_img + colorized_image = color_with_mask(colorized_image, tp_mask, colors_dict["tp"]) + colorized_image = color_with_mask(colorized_image, tn_mask, colors_dict["tn"]) + colorized_image = color_with_mask(colorized_image, fp_mask, colors_dict["fp"]) + colorized_image = color_with_mask(colorized_image, fn_mask, colors_dict["fn"]) + + colorized_image.save(output_file) -- GitLab