# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

import pathlib

import h5py
import numpy
import numpy.typing
import PIL.Image
import PIL.ImageOps
import torch
import torchvision.transforms.functional

from .evaluator import tfpn_masks


def view(
    basedir: pathlib.Path,
    stem: str,
    threshold: float,
    show_errors: bool,
    tp_color: tuple[int, int, int],
    fp_color: tuple[int, int, int],
    fn_color: tuple[int, int, int],
    alpha: float,
) -> PIL.Image.Image:
    """Create a segmentation map visualisation.

    The sample image, probability map, target and mask are read from the HDF5
    file ``basedir / stem``.  Image, prediction and target are restricted to
    the mask, a colourised overlay is built from the thresholded prediction,
    and that overlay is alpha-blended on top of the (masked) image.

    Parameters
    ----------
    basedir
        Base directory where the prediction indicated by ``stem`` is stored.
    stem
        Name of the HDF5 file containing the predictions, as output by the
        ``predict`` CLI, relative to ``basedir``.  The file must contain the
        datasets ``image``, ``prediction``, ``target`` and ``mask``.
    threshold
        The threshold to apply to the probability map loaded from the HDF5
        file.
    show_errors
        If set to ``True``, then false-positives are painted with
        ``fp_color`` and false-negatives with ``fn_color`` on top of the
        true-positives.  Otherwise, every pixel at or above ``threshold`` is
        painted with ``tp_color``.
    tp_color
        Tuple that indicates which color to use for displaying true-positives.
    fp_color
        Tuple that indicates which color to use for displaying false-positives.
    fn_color
        Tuple that indicates which color to use for displaying false-negatives.
    alpha
        Blending weight of the overlay.  A value of 0.0 returns only the
        (masked) image, a value of 1.0 returns only the segmentation overlay.

    Returns
    -------
        An image with an overlayed segmentation map that can be saved or
        displayed.

    Raises
    ------
    KeyError
        If one of the expected datasets is missing from the HDF5 file.
    """

    def _to_pil(arr: numpy.typing.NDArray[numpy.float32]) -> PIL.Image.Image:
        return torchvision.transforms.functional.to_pil_image(torch.Tensor(arr))

    def _colorize(
        binary: numpy.typing.NDArray, color: tuple[int, int, int]
    ) -> tuple[PIL.Image.Image, PIL.Image.Image]:
        # Returns both the gray-level stencil (usable as a paste mask) and its
        # colourised version (black background, ``color`` foreground).
        stencil = _to_pil(binary.astype(numpy.float32))
        return stencil, PIL.ImageOps.colorize(stencil, (0, 0, 0), color)

    # Index the file (instead of ``f.get``) so a missing dataset fails right
    # here with a ``KeyError`` naming it, rather than later with an obscure
    # arithmetic error on ``None``.
    with h5py.File(basedir / stem, "r") as f:
        image: numpy.typing.NDArray[numpy.float32] = numpy.array(f["image"])
        pred: numpy.typing.NDArray[numpy.float32] = numpy.array(f["prediction"])
        target: numpy.typing.NDArray[numpy.bool_] = numpy.array(f["target"])
        mask: numpy.typing.NDArray[numpy.bool_] = numpy.array(f["mask"])

    # Everything outside the mask is zeroed-out / ignored.
    image *= mask
    pred *= mask
    target = numpy.logical_and(target, mask)

    if show_errors:
        # NOTE(review): the third returned value is unused here; presumably
        # the true-negatives -- confirm against ``evaluator.tfpn_masks``.
        tp, fp, _, fn = tfpn_masks(pred, target, threshold)

        _, overlay = _colorize(tp, tp_color)

        # Errors are pasted over the true-positives, using each error map as
        # its own stencil so only the erroneous pixels are replaced.
        for errors, color in ((fp, fp_color), (fn, fn_color)):
            stencil, colored = _colorize(errors, color)
            overlay.paste(colored, mask=stencil)

    else:
        _, overlay = _colorize(pred >= threshold, tp_color)

    # ``colorize`` always produces an RGB image, whereas a single-channel
    # input image is converted to mode "L"; ``blend`` requires both images to
    # share the same mode, so the background is forced to RGB.
    background = _to_pil(image).convert("RGB")
    return PIL.Image.blend(background, overlay, alpha)
from . import ( # analyze, - # compare, config, database, evaluate, predict, - # mkmask, - # significance, train, view, ) @@ -31,11 +28,8 @@ def segmentation(): # segmentation.add_command(analyze.analyze) -# segmentation.add_command(compare.compare) segmentation.add_command(config.config) segmentation.add_command(database.database) -# segmentation.add_command(mkmask.mkmask) -# segmentation.add_command(significance.significance) segmentation.add_command(train.train) segmentation.add_command(predict.predict) segmentation.add_command(evaluate.evaluate) diff --git a/src/mednet/libs/segmentation/scripts/view.py b/src/mednet/libs/segmentation/scripts/view.py index 30c3b355..e0cb9706 100644 --- a/src/mednet/libs/segmentation/scripts/view.py +++ b/src/mednet/libs/segmentation/scripts/view.py @@ -1,171 +1,21 @@ -# SPDX-FileCopyrightText: Copyright © 2024 Idiap Research Institute <contact@idiap.ch> +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later +import json import pathlib +import typing import click -import h5py -import PIL.Image -import torch -from clapper.click import ResourceOption +import tqdm +from clapper.click import ResourceOption, verbosity_option +from clapper.logging import setup from mednet.libs.common.scripts.click import ConfigCommand -from PIL import ImageColor -from PIL.ImageChops import invert, logical_and -from torchvision.transforms.functional import to_pil_image - - -def get_tp_mask(binary_image: PIL.Image, binary_target: PIL.Image) -> PIL.Image: - """Compute the true positive mask. - - Parameters - ---------- - binary_image - B/W image to compare to the target. - binary_target - B/W reference image. - - Returns - ------- - Image with white pixels where both image and target are white, black pixels otherwise. 
- """ - return logical_and(binary_image, binary_target) - - -def get_tn_mask(binary_image: PIL.Image, binary_target: PIL.Image) -> PIL.Image: - """Compute the false positive mask. - - Parameters - ---------- - binary_image - B/W image to compare to the target. - binary_target - B/W reference image. - - Returns - ------- - Image with white pixels where both image and target are black, black pixels otherwise. - """ - return logical_and(invert(binary_image), invert(binary_target)) - - -def get_fp_mask(binary_image: PIL.Image, binary_target: PIL.Image) -> PIL.Image: - """Compute the true positive mask. - - Parameters - ---------- - binary_image - B/W image to compare to the target. - binary_target - B/W reference image. - - Returns - ------- - Image with white pixels where image is white and target is black, black pixels otherwise. - """ - return logical_and(binary_image, invert(binary_target)) - - -def get_fn_mask(binary_image: PIL.Image, binary_target: PIL.Image) -> PIL.Image: - """Compute the true positive mask. - - Parameters - ---------- - binary_image - B/W image to compare to the target. - binary_target - B/W reference image. - - Returns - ------- - Image with white pixels where image is black and target is white, black pixels otherwise. - """ - return logical_and(invert(binary_image), binary_target) - - -def get_masks( - binary_prediction_image: PIL.Image, binary_target_image: PIL.Image -) -> tuple[PIL.Image.Image, PIL.Image, PIL.Image, PIL.Image]: - """Given a B/W binary image and a target, return the tp, tn, fp, fn masks. - - Parameters - ---------- - binary_prediction_image - B/W image. - binary_target_image - B/W reference image. 
- - Returns - ------- - The tp, tn, fp, fn masks - """ - tp_mask = get_tp_mask(binary_prediction_image, binary_target_image) - tn_mask = get_tn_mask(binary_prediction_image, binary_target_image) - fp_mask = get_fp_mask(binary_prediction_image, binary_target_image) - fn_mask = get_fn_mask(binary_prediction_image, binary_target_image) - - return tp_mask, tn_mask, fp_mask, fn_mask - - -def load_image_from_hdf5(filepath: pathlib.Path, key: str) -> PIL.Image: - """Load an image located in an hdf5 file by key. - - Parameters - ---------- - filepath - Path to an hdf5 file. - key - Key to search for in the hdf5 file. - - Returns - ------- - The loaded PIL.Image. - """ - with h5py.File(filepath, "r") as f: - img = to_pil_image(torch.from_numpy(f.get(key)[:])) - - return img # noqa: RET504 - - -def image_to_binary(image: PIL.Image, threshold: int = 127) -> PIL.Image: - """Change the mode of a PIL image to '1' (binary) using a threshold. - - Parameters - ---------- - image - The image to convert to binary mode. - threshold - The threshold to use for convertion, with p <= threshold = black, p > threshold = white. - - Returns - ------- - The binary image. - """ - image = image.point(lambda p: 255 if p > threshold else 0) - return image.convert("1") - - -def color_with_mask(image: PIL.Image, mask: PIL.Image, color: ImageColor) -> PIL.Image: - """Colorize the image with a given color by using a mask. - - Parameters - ---------- - image - The image to colorize. - mask - Mask used to indicate where to apply the color. - color - The color to apply. - - Returns - ------- - The colorized image. 
- """ - color_plane = PIL.Image.new(mode="RGB", size=image.size, color=color) - - image = image.convert("RGB") - image.paste(color_plane, mask) - return image +from mednet.libs.segmentation.engine.evaluator import SUPPORTED_METRIC_TYPE + +from .evaluate import validate_threshold + +logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") @click.command( @@ -174,66 +24,184 @@ def color_with_mask(image: PIL.Image, mask: PIL.Image, color: ImageColor) -> PIL epilog="""Examples: \b - 1. Load images from an hdf5 file and saves a new image with the tp, tn, fp, fn colorized: + 1. Runs evaluation on an existing dataset configuration: .. code:: sh - $ mednet segmentation view -f results/predictions/test/test.hdf5 -o colorized_prediction.png - + $ mednet segmentation view -vv --predictions=path/to/predictions.json --output-folder=path/to/results """, ) @click.option( - "--hdf5-file", - "-f", - help="File in which predictions are currently stored", + "--predictions", + "-p", + help="""Path to the JSON file describing available predictions. The actual + predictions are supposed to lie on the same folder.""", required=True, type=click.Path( file_okay=True, dir_okay=False, - writable=True, + writable=False, path_type=pathlib.Path, ), cls=ResourceOption, ) @click.option( - "--output-file", + "--output-folder", "-o", - help="File in which to store the result (created if does not exist)", + help="Directory in which to store results (created if does not exist)", required=True, type=click.Path( - file_okay=True, - dir_okay=False, + file_okay=False, + dir_okay=True, writable=True, path_type=pathlib.Path, ), - default="segmentation.png", + default="results", + cls=ResourceOption, +) +@click.option( + "--threshold", + "-t", + help="""This number is used to define positives and negatives from + probability maps, and used to report metrics based on a threshold chosen *a + priori*. 
It can be set to a floating-point value, or to the name of dataset + split in ``--predictions``. + """, + default="0.5", + show_default=True, + required=False, + cls=ResourceOption, +) +@click.option( + "--metric", + "-m", + help="""If threshold is set to the name of a split in ``--predictions``, + then this parameter defines the metric function to be used to evaluate the + threshold at which the metric reaches its maximum value. All other splits + are evaluated with respect to this threshold.""", + default="f1", + type=click.Choice(typing.get_args(SUPPORTED_METRIC_TYPE), case_sensitive=True), + show_default=True, + required=True, + cls=ResourceOption, +) +@click.option( + "--steps", + "-s", + help="""Number of steps for evaluating metrics on various splits. This + value is used when drawing precision-recall plots, or when deciding the + highest metric value on splits.""", + default=100, + type=click.IntRange(10), + show_default=True, + required=True, + cls=ResourceOption, +) +@click.option( + "--show-errors/--no-show-errors", + "-e/-E", + help="""If set, then shows a colorized version of the segmentation map in + which false-positives are marked in red, and false-negatives in green. + True positives are always marked in white.""", + default=False, + show_default=True, + required=True, cls=ResourceOption, ) +@click.option( + "--alpha", + "-a", + help="""Defines the transparency weighting between the original image and + the predicted segmentation maps. A value of 1.0 makes the program output + only segmentation maps. 
A value of 0.0 makes the program output only the + processed image.""", + default=0.6, + type=click.FloatRange(0.0, 1.0), + show_default=True, + required=True, + cls=ResourceOption, +) +@verbosity_option(logger=logger, cls=ResourceOption, expose_value=False) def view( - hdf5_file: pathlib.Path, - output_file: pathlib.Path, + predictions: pathlib.Path, + output_folder: pathlib.Path, + threshold: str | float, + metric: str, + steps: int, + show_errors: bool, + alpha: float, **_, # ignored ): # numpydoc ignore=PR01 - """Load images from an hdf5 file and saves a new image with the tp, tn, fp, fn colorized.""" - colors_dict = { - "tp": ImageColor.getcolor("white", "RGB"), - "tn": ImageColor.getcolor("black", "RGB"), - "fp": ImageColor.getcolor("green", "RGB"), - "fn": ImageColor.getcolor("red", "RGB"), - } - - pred_img = load_image_from_hdf5(hdf5_file, "img") - binary_pred_img = image_to_binary(pred_img, 127) - - target_img = load_image_from_hdf5(hdf5_file, "target") - binary_target_img = image_to_binary(target_img, 127) - - tp_mask, tn_mask, fp_mask, fn_mask = get_masks(binary_pred_img, binary_target_img) - - colorized_image = pred_img - colorized_image = color_with_mask(colorized_image, tp_mask, colors_dict["tp"]) - colorized_image = color_with_mask(colorized_image, tn_mask, colors_dict["tn"]) - colorized_image = color_with_mask(colorized_image, fp_mask, colors_dict["fp"]) - colorized_image = color_with_mask(colorized_image, fn_mask, colors_dict["fn"]) - - colorized_image.save(output_file) + """Evaluate predictions (from a model) on a segmentation task.""" + + import numpy + from mednet.libs.common.scripts.utils import ( + execution_metadata, + save_json_with_backup, + ) + from mednet.libs.segmentation.engine.evaluator import ( + compute_metric, + load_count, + name2metric, + ) + from mednet.libs.segmentation.engine.viewer import view + + evaluation_filename = "evaluation.json" + evaluation_file = output_folder / evaluation_filename + + with predictions.open("r") as 
f: + predict_data = json.load(f) + + # register metadata + json_data: dict[str, typing.Any] = execution_metadata() + json_data.update( + dict( + predictions=str(predictions), + output_folder=str(output_folder), + threshold=threshold, + metric=metric, + steps=steps, + ), + ) + json_data = {k.replace("_", "-"): v for k, v in json_data.items()} + save_json_with_backup(evaluation_file.with_suffix(".meta.json"), json_data) + + threshold = validate_threshold(threshold, predict_data) + threshold_list = numpy.arange( + 0.0, (1.0 + 1 / steps), 1 / steps, dtype=numpy.float64 + ) + + if isinstance(threshold, str): + # Compute threshold on specified split, if required + logger.info(f"Evaluating threshold on `{threshold}` split using " f"`{metric}`") + counts = load_count(predictions.parent, predict_data[threshold], threshold_list) + metric_list = compute_metric( + counts, name2metric(typing.cast(SUPPORTED_METRIC_TYPE, metric)) + ) + threshold_index = metric_list.argmax() + logger.info(f"Set --threshold={threshold_list[threshold_index]:.4f}") + + else: + # must figure out the closest threshold from the list we are using + threshold_index = (numpy.abs(threshold_list - threshold)).argmin() + logger.info(f"Set --threshold={threshold_list[threshold_index]:.4f}") + + # create visualisations + for split_name, sample_list in predict_data.items(): + logger.info( + f"Creating {len(sample_list)} visualisations for split `{split_name}`" + ) + for sample in tqdm.tqdm(sample_list): + image = view( + predictions.parent, + sample[1], + threshold=threshold_list[threshold_index], + show_errors=show_errors, + tp_color=(255, 255, 255), + fp_color=(255, 0, 0), + fn_color=(0, 255, 0), + alpha=alpha, + ) + dest = (output_folder / sample[1]).with_suffix(".png") + tqdm.tqdm.write(f"{sample[1]} -> {dest}") + image.save(dest) -- GitLab