Skip to content
Snippets Groups Projects
Commit f869f81e authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[libs.segmentation.scripts.view] Implement new view script using the output of...

[libs.segmentation.scripts.view] Implement new view script using the output of the predict step; Implement colorization/transparency options
parent 653ec39e
No related branches found
No related tags found
1 merge request!46Create common library
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
import pathlib
import h5py
import numpy
import numpy.typing
import PIL.Image
import PIL.ImageOps
import torch
import torchvision.transforms.functional
from .evaluator import tfpn_masks
def view(
basedir: pathlib.Path,
stem: str,
threshold: float,
show_errors: bool,
tp_color: tuple[int, int, int],
fp_color: tuple[int, int, int],
fn_color: tuple[int, int, int],
alpha: float,
) -> PIL.Image.Image:
"""Create an segmentation map visualisation.
Parameters
----------
basedir
Base directory where the prediction indicated by ``stem`` is stored.
stem
Name of the HDF5 file containing the predictions, as output by the
``predict`` CLI.
threshold
The threshold to apply to the probability map loaded from the HDF5
file.
show_errors
If set to ``True``, then colours false-positives (in red), and false
negatives (in green).
tp_color
Tuple that indicates which color to use for displaying true-positives.
fp_color
Tuple that indicates which color to use for displaying false-positives.
fn_color
Tuple that indicates which color to use for displaying false-negatives.
alpha
How transparent will the overlay be.
Returns
-------
An image with an overlayed segmentation map that can be saved or
displayed.
"""
def _to_pil(arr: numpy.typing.NDArray[numpy.float32]) -> PIL.Image.Image:
return torchvision.transforms.functional.to_pil_image(torch.Tensor(arr))
with h5py.File(basedir / stem, "r") as f:
image: numpy.typing.NDArray[numpy.float32] = numpy.array(f.get("image"))
pred: numpy.typing.NDArray[numpy.float32] = numpy.array(f.get("prediction"))
target: numpy.typing.NDArray[numpy.bool_] = numpy.array(f.get("target"))
mask: numpy.typing.NDArray[numpy.bool_] = numpy.array(f.get("mask"))
image *= mask
pred *= mask
target = numpy.logical_and(target, mask)
if show_errors:
tp, fp, _, fn = tfpn_masks(pred, target, threshold)
# change to PIL representation
tp_pil = _to_pil(tp.astype(numpy.float32))
tp_pil_colored = PIL.ImageOps.colorize(tp_pil, (0, 0, 0), tp_color)
fp_pil = _to_pil(fp.astype(numpy.float32))
fp_pil_colored = PIL.ImageOps.colorize(fp_pil, (0, 0, 0), fp_color)
fn_pil = _to_pil(fn.astype(numpy.float32))
fn_pil_colored = PIL.ImageOps.colorize(fn_pil, (0, 0, 0), fn_color)
tp_pil_colored.paste(fp_pil_colored, mask=fp_pil)
tp_pil_colored.paste(fn_pil_colored, mask=fn_pil)
else:
overlay = pred >= threshold
tp_pil = _to_pil(overlay.astype(numpy.float32))
tp_pil_colored = PIL.ImageOps.colorize(tp_pil, (0, 0, 0), tp_color)
retval = _to_pil(image)
return PIL.Image.blend(retval, tp_pil_colored, alpha)
...@@ -9,13 +9,10 @@ from clapper.click import AliasedGroup ...@@ -9,13 +9,10 @@ from clapper.click import AliasedGroup
from . import ( from . import (
# analyze, # analyze,
# compare,
config, config,
database, database,
evaluate, evaluate,
predict, predict,
# mkmask,
# significance,
train, train,
view, view,
) )
...@@ -31,11 +28,8 @@ def segmentation(): ...@@ -31,11 +28,8 @@ def segmentation():
# segmentation.add_command(analyze.analyze) # segmentation.add_command(analyze.analyze)
# segmentation.add_command(compare.compare)
segmentation.add_command(config.config) segmentation.add_command(config.config)
segmentation.add_command(database.database) segmentation.add_command(database.database)
# segmentation.add_command(mkmask.mkmask)
# segmentation.add_command(significance.significance)
segmentation.add_command(train.train) segmentation.add_command(train.train)
segmentation.add_command(predict.predict) segmentation.add_command(predict.predict)
segmentation.add_command(evaluate.evaluate) segmentation.add_command(evaluate.evaluate)
......
# SPDX-FileCopyrightText: Copyright © 2024 Idiap Research Institute <contact@idiap.ch> # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
# #
# SPDX-License-Identifier: GPL-3.0-or-later # SPDX-License-Identifier: GPL-3.0-or-later
import json
import pathlib import pathlib
import typing
import click import click
import h5py import tqdm
import PIL.Image from clapper.click import ResourceOption, verbosity_option
import torch from clapper.logging import setup
from clapper.click import ResourceOption
from mednet.libs.common.scripts.click import ConfigCommand from mednet.libs.common.scripts.click import ConfigCommand
from PIL import ImageColor from mednet.libs.segmentation.engine.evaluator import SUPPORTED_METRIC_TYPE
from PIL.ImageChops import invert, logical_and
from torchvision.transforms.functional import to_pil_image from .evaluate import validate_threshold
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
def get_tp_mask(binary_image: PIL.Image, binary_target: PIL.Image) -> PIL.Image:
"""Compute the true positive mask.
Parameters
----------
binary_image
B/W image to compare to the target.
binary_target
B/W reference image.
Returns
-------
Image with white pixels where both image and target are white, black pixels otherwise.
"""
return logical_and(binary_image, binary_target)
def get_tn_mask(binary_image: PIL.Image, binary_target: PIL.Image) -> PIL.Image:
"""Compute the false positive mask.
Parameters
----------
binary_image
B/W image to compare to the target.
binary_target
B/W reference image.
Returns
-------
Image with white pixels where both image and target are black, black pixels otherwise.
"""
return logical_and(invert(binary_image), invert(binary_target))
def get_fp_mask(binary_image: PIL.Image, binary_target: PIL.Image) -> PIL.Image:
"""Compute the true positive mask.
Parameters
----------
binary_image
B/W image to compare to the target.
binary_target
B/W reference image.
Returns
-------
Image with white pixels where image is white and target is black, black pixels otherwise.
"""
return logical_and(binary_image, invert(binary_target))
def get_fn_mask(binary_image: PIL.Image, binary_target: PIL.Image) -> PIL.Image:
"""Compute the true positive mask.
Parameters
----------
binary_image
B/W image to compare to the target.
binary_target
B/W reference image.
Returns
-------
Image with white pixels where image is black and target is white, black pixels otherwise.
"""
return logical_and(invert(binary_image), binary_target)
def get_masks(
binary_prediction_image: PIL.Image, binary_target_image: PIL.Image
) -> tuple[PIL.Image.Image, PIL.Image, PIL.Image, PIL.Image]:
"""Given a B/W binary image and a target, return the tp, tn, fp, fn masks.
Parameters
----------
binary_prediction_image
B/W image.
binary_target_image
B/W reference image.
Returns
-------
The tp, tn, fp, fn masks
"""
tp_mask = get_tp_mask(binary_prediction_image, binary_target_image)
tn_mask = get_tn_mask(binary_prediction_image, binary_target_image)
fp_mask = get_fp_mask(binary_prediction_image, binary_target_image)
fn_mask = get_fn_mask(binary_prediction_image, binary_target_image)
return tp_mask, tn_mask, fp_mask, fn_mask
def load_image_from_hdf5(filepath: pathlib.Path, key: str) -> PIL.Image:
"""Load an image located in an hdf5 file by key.
Parameters
----------
filepath
Path to an hdf5 file.
key
Key to search for in the hdf5 file.
Returns
-------
The loaded PIL.Image.
"""
with h5py.File(filepath, "r") as f:
img = to_pil_image(torch.from_numpy(f.get(key)[:]))
return img # noqa: RET504
def image_to_binary(image: PIL.Image, threshold: int = 127) -> PIL.Image:
"""Change the mode of a PIL image to '1' (binary) using a threshold.
Parameters
----------
image
The image to convert to binary mode.
threshold
The threshold to use for convertion, with p <= threshold = black, p > threshold = white.
Returns
-------
The binary image.
"""
image = image.point(lambda p: 255 if p > threshold else 0)
return image.convert("1")
def color_with_mask(image: PIL.Image, mask: PIL.Image, color: ImageColor) -> PIL.Image:
"""Colorize the image with a given color by using a mask.
Parameters
----------
image
The image to colorize.
mask
Mask used to indicate where to apply the color.
color
The color to apply.
Returns
-------
The colorized image.
"""
color_plane = PIL.Image.new(mode="RGB", size=image.size, color=color)
image = image.convert("RGB")
image.paste(color_plane, mask)
return image
@click.command( @click.command(
...@@ -174,66 +24,184 @@ def color_with_mask(image: PIL.Image, mask: PIL.Image, color: ImageColor) -> PIL ...@@ -174,66 +24,184 @@ def color_with_mask(image: PIL.Image, mask: PIL.Image, color: ImageColor) -> PIL
epilog="""Examples: epilog="""Examples:
\b \b
1. Load images from an hdf5 file and saves a new image with the tp, tn, fp, fn colorized: 1. Runs evaluation on an existing dataset configuration:
.. code:: sh .. code:: sh
$ mednet segmentation view -f results/predictions/test/test.hdf5 -o colorized_prediction.png $ mednet segmentation view -vv --predictions=path/to/predictions.json --output-folder=path/to/results
""", """,
) )
@click.option( @click.option(
"--hdf5-file", "--predictions",
"-f", "-p",
help="File in which predictions are currently stored", help="""Path to the JSON file describing available predictions. The actual
predictions are supposed to lie on the same folder.""",
required=True, required=True,
type=click.Path( type=click.Path(
file_okay=True, file_okay=True,
dir_okay=False, dir_okay=False,
writable=True, writable=False,
path_type=pathlib.Path, path_type=pathlib.Path,
), ),
cls=ResourceOption, cls=ResourceOption,
) )
@click.option( @click.option(
"--output-file", "--output-folder",
"-o", "-o",
help="File in which to store the result (created if does not exist)", help="Directory in which to store results (created if does not exist)",
required=True, required=True,
type=click.Path( type=click.Path(
file_okay=True, file_okay=False,
dir_okay=False, dir_okay=True,
writable=True, writable=True,
path_type=pathlib.Path, path_type=pathlib.Path,
), ),
default="segmentation.png", default="results",
cls=ResourceOption,
)
@click.option(
"--threshold",
"-t",
help="""This number is used to define positives and negatives from
probability maps, and used to report metrics based on a threshold chosen *a
priori*. It can be set to a floating-point value, or to the name of dataset
split in ``--predictions``.
""",
default="0.5",
show_default=True,
required=False,
cls=ResourceOption,
)
@click.option(
"--metric",
"-m",
help="""If threshold is set to the name of a split in ``--predictions``,
then this parameter defines the metric function to be used to evaluate the
threshold at which the metric reaches its maximum value. All other splits
are evaluated with respect to this threshold.""",
default="f1",
type=click.Choice(typing.get_args(SUPPORTED_METRIC_TYPE), case_sensitive=True),
show_default=True,
required=True,
cls=ResourceOption,
)
@click.option(
"--steps",
"-s",
help="""Number of steps for evaluating metrics on various splits. This
value is used when drawing precision-recall plots, or when deciding the
highest metric value on splits.""",
default=100,
type=click.IntRange(10),
show_default=True,
required=True,
cls=ResourceOption,
)
@click.option(
"--show-errors/--no-show-errors",
"-e/-E",
help="""If set, then shows a colorized version of the segmentation map in
which false-positives are marked in red, and false-negatives in green.
True positives are always marked in white.""",
default=False,
show_default=True,
required=True,
cls=ResourceOption, cls=ResourceOption,
) )
@click.option(
"--alpha",
"-a",
help="""Defines the transparency weighting between the original image and
the predicted segmentation maps. A value of 1.0 makes the program output
only segmentation maps. A value of 0.0 makes the program output only the
processed image.""",
default=0.6,
type=click.FloatRange(0.0, 1.0),
show_default=True,
required=True,
cls=ResourceOption,
)
@verbosity_option(logger=logger, cls=ResourceOption, expose_value=False)
def view( def view(
hdf5_file: pathlib.Path, predictions: pathlib.Path,
output_file: pathlib.Path, output_folder: pathlib.Path,
threshold: str | float,
metric: str,
steps: int,
show_errors: bool,
alpha: float,
**_, # ignored **_, # ignored
): # numpydoc ignore=PR01 ): # numpydoc ignore=PR01
"""Load images from an hdf5 file and saves a new image with the tp, tn, fp, fn colorized.""" """Evaluate predictions (from a model) on a segmentation task."""
colors_dict = {
"tp": ImageColor.getcolor("white", "RGB"), import numpy
"tn": ImageColor.getcolor("black", "RGB"), from mednet.libs.common.scripts.utils import (
"fp": ImageColor.getcolor("green", "RGB"), execution_metadata,
"fn": ImageColor.getcolor("red", "RGB"), save_json_with_backup,
} )
from mednet.libs.segmentation.engine.evaluator import (
pred_img = load_image_from_hdf5(hdf5_file, "img") compute_metric,
binary_pred_img = image_to_binary(pred_img, 127) load_count,
name2metric,
target_img = load_image_from_hdf5(hdf5_file, "target") )
binary_target_img = image_to_binary(target_img, 127) from mednet.libs.segmentation.engine.viewer import view
tp_mask, tn_mask, fp_mask, fn_mask = get_masks(binary_pred_img, binary_target_img) evaluation_filename = "evaluation.json"
evaluation_file = output_folder / evaluation_filename
colorized_image = pred_img
colorized_image = color_with_mask(colorized_image, tp_mask, colors_dict["tp"]) with predictions.open("r") as f:
colorized_image = color_with_mask(colorized_image, tn_mask, colors_dict["tn"]) predict_data = json.load(f)
colorized_image = color_with_mask(colorized_image, fp_mask, colors_dict["fp"])
colorized_image = color_with_mask(colorized_image, fn_mask, colors_dict["fn"]) # register metadata
json_data: dict[str, typing.Any] = execution_metadata()
colorized_image.save(output_file) json_data.update(
dict(
predictions=str(predictions),
output_folder=str(output_folder),
threshold=threshold,
metric=metric,
steps=steps,
),
)
json_data = {k.replace("_", "-"): v for k, v in json_data.items()}
save_json_with_backup(evaluation_file.with_suffix(".meta.json"), json_data)
threshold = validate_threshold(threshold, predict_data)
threshold_list = numpy.arange(
0.0, (1.0 + 1 / steps), 1 / steps, dtype=numpy.float64
)
if isinstance(threshold, str):
# Compute threshold on specified split, if required
logger.info(f"Evaluating threshold on `{threshold}` split using " f"`{metric}`")
counts = load_count(predictions.parent, predict_data[threshold], threshold_list)
metric_list = compute_metric(
counts, name2metric(typing.cast(SUPPORTED_METRIC_TYPE, metric))
)
threshold_index = metric_list.argmax()
logger.info(f"Set --threshold={threshold_list[threshold_index]:.4f}")
else:
# must figure out the closest threshold from the list we are using
threshold_index = (numpy.abs(threshold_list - threshold)).argmin()
logger.info(f"Set --threshold={threshold_list[threshold_index]:.4f}")
# create visualisations
for split_name, sample_list in predict_data.items():
logger.info(
f"Creating {len(sample_list)} visualisations for split `{split_name}`"
)
for sample in tqdm.tqdm(sample_list):
image = view(
predictions.parent,
sample[1],
threshold=threshold_list[threshold_index],
show_errors=show_errors,
tp_color=(255, 255, 255),
fp_color=(255, 0, 0),
fn_color=(0, 255, 0),
alpha=alpha,
)
dest = (output_folder / sample[1]).with_suffix(".png")
tqdm.tqdm.write(f"{sample[1]} -> {dest}")
image.save(dest)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment