Skip to content
Snippets Groups Projects
Commit 9a399fad authored by Daniel CARRON's avatar Daniel CARRON :b: Committed by André Anjos
Browse files

[segmentation.view] Add view script

parent afaf0cfe
No related branches found
No related tags found
1 merge request!46Create common library
...@@ -17,6 +17,7 @@ from . import ( ...@@ -17,6 +17,7 @@ from . import (
# mkmask, # mkmask,
# significance, # significance,
train, train,
view,
) )
...@@ -44,6 +45,7 @@ segmentation.add_command( ...@@ -44,6 +45,7 @@ segmentation.add_command(
package=__name__, package=__name__,
).train_analysis, ).train_analysis,
) )
segmentation.add_command(view.view)
segmentation.add_command( segmentation.add_command(
importlib.import_module("..experiment", package=__name__).experiment, importlib.import_module("..experiment", package=__name__).experiment,
) )
# SPDX-FileCopyrightText: Copyright © 2024 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
import pathlib
import click
import h5py
import PIL.Image
import torch
from clapper.click import ResourceOption
from mednet.libs.common.scripts.click import ConfigCommand
from PIL import ImageColor
from PIL.ImageChops import invert, logical_and
from torchvision.transforms.functional import to_pil_image
def get_tp_mask(binary_image: PIL.Image, binary_target: PIL.Image) -> PIL.Image:
"""Compute the true positive mask.
Parameters
----------
binary_image
B/W image to compare to the target.
binary_target
B/W reference image.
Returns
-------
Image with white pixels where both image and target are white, black pixels otherwise.
"""
return logical_and(binary_image, binary_target)
def get_tn_mask(binary_image: PIL.Image, binary_target: PIL.Image) -> PIL.Image:
"""Compute the false positive mask.
Parameters
----------
binary_image
B/W image to compare to the target.
binary_target
B/W reference image.
Returns
-------
Image with white pixels where both image and target are black, black pixels otherwise.
"""
return logical_and(invert(binary_image), invert(binary_target))
def get_fp_mask(binary_image: PIL.Image, binary_target: PIL.Image) -> PIL.Image:
"""Compute the true positive mask.
Parameters
----------
binary_image
B/W image to compare to the target.
binary_target
B/W reference image.
Returns
-------
Image with white pixels where image is white and target is black, black pixels otherwise.
"""
return logical_and(binary_image, invert(binary_target))
def get_fn_mask(binary_image: PIL.Image, binary_target: PIL.Image) -> PIL.Image:
"""Compute the true positive mask.
Parameters
----------
binary_image
B/W image to compare to the target.
binary_target
B/W reference image.
Returns
-------
Image with white pixels where image is black and target is white, black pixels otherwise.
"""
return logical_and(invert(binary_image), binary_target)
def get_masks(
binary_prediction_image: PIL.Image, binary_target_image: PIL.Image
) -> tuple[PIL.Image.Image, PIL.Image, PIL.Image, PIL.Image]:
"""Given a B/W binary image and a target, return the tp, tn, fp, fn masks.
Parameters
----------
binary_prediction_image
B/W image.
binary_target_image
B/W reference image.
Returns
-------
The tp, tn, fp, fn masks
"""
tp_mask = get_tp_mask(binary_prediction_image, binary_target_image)
tn_mask = get_tn_mask(binary_prediction_image, binary_target_image)
fp_mask = get_fp_mask(binary_prediction_image, binary_target_image)
fn_mask = get_fn_mask(binary_prediction_image, binary_target_image)
return tp_mask, tn_mask, fp_mask, fn_mask
def load_image_from_hdf5(filepath: pathlib.Path, key: str) -> PIL.Image:
"""Load an image located in an hdf5 file by key.
Parameters
----------
filepath
Path to an hdf5 file.
key
Key to search for in the hdf5 file.
Returns
-------
The loaded PIL.Image.
"""
with h5py.File(filepath, "r") as f:
img = to_pil_image(torch.from_numpy(f.get(key)[:]))
return img # noqa: RET504
def image_to_binary(image: PIL.Image, threshold: int = 127) -> PIL.Image:
"""Change the mode of a PIL image to '1' (binary) using a threshold.
Parameters
----------
image
The image to convert to binary mode.
threshold
The threshold to use for convertion, with p <= threshold = black, p > threshold = white.
Returns
-------
The binary image.
"""
image = image.point(lambda p: 255 if p > threshold else 0)
return image.convert("1")
def color_with_mask(image: PIL.Image, mask: PIL.Image, color: ImageColor) -> PIL.Image:
"""Colorize the image with a given color by using a mask.
Parameters
----------
image
The image to colorize.
mask
Mask used to indicate where to apply the color.
color
The color to apply.
Returns
-------
The colorized image.
"""
color_plane = PIL.Image.new(mode="RGB", size=image.size, color=color)
image = image.convert("RGB")
image.paste(color_plane, mask)
return image
@click.command(
entry_point_group="mednet.libs.segmentation.config",
cls=ConfigCommand,
epilog="""Examples:
\b
1. Load images from an hdf5 file and saves a new image with the tp, tn, fp, fn colorized:
.. code:: sh
$ mednet segmentation view -f results/predictions/test/test.hdf5 -o colorized_prediction.png
""",
)
@click.option(
"--hdf5-file",
"-f",
help="File in which predictions are currently stored",
required=True,
type=click.Path(
file_okay=True,
dir_okay=False,
writable=True,
path_type=pathlib.Path,
),
cls=ResourceOption,
)
@click.option(
"--output-file",
"-o",
help="File in which to store the result (created if does not exist)",
required=True,
type=click.Path(
file_okay=True,
dir_okay=False,
writable=True,
path_type=pathlib.Path,
),
default="segmentation.png",
cls=ResourceOption,
)
def view(
hdf5_file: pathlib.Path,
output_file: pathlib.Path,
**_, # ignored
): # numpydoc ignore=PR01
"""Load images from an hdf5 file and saves a new image with the tp, tn, fp, fn colorized."""
colors_dict = {
"tp": ImageColor.getcolor("white", "RGB"),
"tn": ImageColor.getcolor("black", "RGB"),
"fp": ImageColor.getcolor("green", "RGB"),
"fn": ImageColor.getcolor("red", "RGB"),
}
pred_img = load_image_from_hdf5(hdf5_file, "img")
binary_pred_img = image_to_binary(pred_img, 127)
target_img = load_image_from_hdf5(hdf5_file, "target")
binary_target_img = image_to_binary(target_img, 127)
tp_mask, tn_mask, fp_mask, fn_mask = get_masks(binary_pred_img, binary_target_img)
colorized_image = pred_img
colorized_image = color_with_mask(colorized_image, tp_mask, colors_dict["tp"])
colorized_image = color_with_mask(colorized_image, tn_mask, colors_dict["tn"])
colorized_image = color_with_mask(colorized_image, fp_mask, colors_dict["fp"])
colorized_image = color_with_mask(colorized_image, fn_mask, colors_dict["fn"])
colorized_image.save(output_file)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment