diff --git a/src/ptbench/engine/visualizer.py b/src/ptbench/engine/visualizer.py deleted file mode 100644 index fed11421bc9380b57c3dcda2846ba8bfb5ebd3eb..0000000000000000000000000000000000000000 --- a/src/ptbench/engine/visualizer.py +++ /dev/null @@ -1,231 +0,0 @@ -# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> -# -# SPDX-License-Identifier: GPL-3.0-or-later - -import logging -import os - -import cv2 -import numpy as np -import pandas as pd - -from torchvision.transforms.functional import to_pil_image -from tqdm import tqdm - -from ..utils.cam_utils import ( - draw_boxes_on_image, - draw_largest_component_bbox_on_image, - show_cam_on_image, - visualize_road_scores, -) - -logger = logging.getLogger(__name__) - - -def _get_target_classes_from_directory(input_folder): - # Gets a list of target classes from a directory - return [ - item - for item in os.listdir(input_folder) - if os.path.isdir(os.path.join(input_folder, item)) - ] - - -def run( - data_loader, - input_folder, - output_folder, - dataset_split_name, - visualize_groundtruth=False, - road_path=None, - visualize_detected_bbox=True, - threshold=0.0, -): - """Overlays saliency maps on CXR to output final images with heatmaps. - - Parameters - --------- - img_dir : str - Directory containing the images from the dataset to use for visualization. - - input_folder : str - Directory in which the saliency maps are stored for a specific visualization type. - - output_folder : str - Directory in which the results will be saved. - - visualize_groundtruth : bool - If set, generate visualizations for ground truth labels. - - road_path : str - If the path to the previously calculated road scores is provided, MoRF, LeRF, and combined ROAD scores will be visualized on each image. - - visualize_detected_bbox : bool - If set, the bounding box for the largest connected component will be visualized on each image. - - threshold : float - Only CAMs above this threshold will be visualized. - - Returns - ------- - - output_folder : str - Directory containing the images overlaid with heatmaps. - """ - - target_classes = _get_target_classes_from_directory(input_folder) - - shown_warnings = set() - - if visualize_detected_bbox: - # Load detected bounding box information - localization_metric_dfs = {} - - for target_class in target_classes: - input_base_path = f"{input_folder}/{target_class}/{dataset_split_name}_localization_metrics.csv" - - if os.path.exists(input_base_path): - df = pd.read_csv(input_base_path) - localization_metric_dfs[target_class] = df - - if road_path is not None: - # Load road score information - road_metric_dfs = {} - - for target_class in target_classes: - road_base_path = f"{road_path}/{target_class}/{dataset_split_name}_road_metrics.csv" - - if os.path.exists(road_base_path): - df = pd.read_csv(road_base_path) - road_metric_dfs[target_class] = df - - for samples in tqdm(data_loader, desc="batches", leave=False, disable=None): - includes_bboxes = True - if ( - samples[1]["label"].item() == 0 - or "bounding_boxes" not in samples[1] - ): - includes_bboxes = False - else: - gt_bboxes = samples[1]["bounding_boxes"] - if not gt_bboxes: - includes_bboxes = False - - names = samples[1]["name"] - - images = samples[0] - - if ( - visualize_groundtruth - and not includes_bboxes - and samples[1]["label"].item() == 1 - ): - logger.warning( - f'Sample "{names[0]}" does not have bounding box information. No ground truth bounding boxes can be visualized.' - ) - - for target_class in target_classes: - input_base_path = ( - f"{input_folder}/{target_class}/{dataset_split_name}" - ) - vis_directory_name = os.path.basename(input_folder) - output_base_path = f"{output_folder}/{vis_directory_name}/{target_class}/{dataset_split_name}" - - saliency_map_name = names[0].rsplit(".", 1)[0] + ".npy" - saliency_map_path = os.path.join(input_base_path, saliency_map_name) - - # Skip if the saliency map does not exist (e.g. if no saliency maps were generated for TB negative samples) - if not os.path.exists(saliency_map_path): - continue - - saliency_map = np.load(saliency_map_path) - - # Make sure the image is RGB - pil_img = to_pil_image(images[0]) - pil_img = pil_img.convert("RGB") - - # Image is expected to be of type float32 - original_img = np.array(pil_img, dtype=np.float32) - img_with_boxes = original_img.copy() - - # Draw bounding boxes on the image - if visualize_groundtruth and includes_bboxes: - img_with_boxes = draw_boxes_on_image(img_with_boxes, gt_bboxes) - - # show_cam_on_image expects the img to be between [0, 1] - rgb_img = img_with_boxes / 255.0 - - # Threshold heatmap - thresholded_mask = np.copy(saliency_map) - thresholded_mask[thresholded_mask < threshold] = 0 - - # Visualize heatmap - visualization = show_cam_on_image( - rgb_img, thresholded_mask, use_rgb=False - ) - - # Visualize bounding box for largest connected component - if visualize_detected_bbox: - if target_class in localization_metric_dfs: - temp_df = localization_metric_dfs[target_class] - temp_filename = names[0] - - row = temp_df[temp_df["Image"] == temp_filename] - - if not row.empty: - x_L = int(row["detected_bbox_xmin"].values[0]) - y_L = int(row["detected_bbox_ymin"].values[0]) - w_L = int(row["detected_bbox_width"].values[0]) - h_L = int(row["detected_bbox_height"].values[0]) - - visualization = draw_largest_component_bbox_on_image( - visualization, x_L, y_L, w_L, h_L - ) - else: - logger.warning( - f"Could not find entry for {temp_filename} in .csv file." - ) - else: - error_message = f"No localization metrics .csv file found for the {dataset_split_name} split for {target_class}." - if error_message not in shown_warnings: - logger.warning(error_message) - shown_warnings.add(error_message) - - # Visualize ROAD scores - if road_path is not None: - if target_class in road_metric_dfs: - temp_df = road_metric_dfs[target_class] - temp_filename = names[0] - - row = temp_df[temp_df["Image"] == temp_filename] - - if not row.empty: - MoRF_score = float(row["MoRF"].values[0]) - LeRF_score = float(row["LeRF"].values[0]) - combined_score = float( - row["Combined Score ((LeRF-MoRF) / 2)"].values[0] - ) - percentiles = str(row["Percentiles"].values[0]) - - visualization = visualize_road_scores( - visualization, - MoRF_score, - LeRF_score, - combined_score, - vis_directory_name, - percentiles, - ) - else: - logger.warning( - f"Could not find entry for {temp_filename} in .csv file." - ) - else: - error_message = f"No ROAD metrics .csv file found for the {dataset_split_name} split for {target_class}." - if error_message not in shown_warnings: - logger.warning(error_message) - shown_warnings.add(error_message) - - # Save image - output_file_path = os.path.join(output_base_path, names[0]) - os.makedirs(os.path.dirname(output_file_path), exist_ok=True) - cv2.imwrite(output_file_path, visualization) diff --git a/src/ptbench/scripts/compare_vis.py b/src/ptbench/scripts/compare_vis.py deleted file mode 100644 index bade26b87c4e53b2d18bfe1c9450b8937b8b6b08..0000000000000000000000000000000000000000 --- a/src/ptbench/scripts/compare_vis.py +++ /dev/null @@ -1,207 +0,0 @@ -# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> -# -# SPDX-License-Identifier: GPL-3.0-or-later - -import os - -import click - -from clapper.click import verbosity_option -from clapper.logging import setup - -logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") - - -def _sorting_rule(folder_name): - folder_name = os.path.basename(folder_name) - - if "gradcam" == folder_name: - return (0, folder_name) - elif "scorecam" == folder_name: - return (1, folder_name) - elif "fullgrad" == folder_name: - return (2, folder_name) - elif "randomcam" == folder_name: - return (4, folder_name) - else: # Everything else will be sorted alphabetically after fullgrad and before randomcam - return (3, folder_name) - - -def _get_images_from_directory(dir_path): - image_files = [] - for root, _, files in os.walk(dir_path): - for file in files: - if file.lower().endswith((".png", ".jpg", ".jpeg")): - image_files.append(os.path.join(root, file)) - return image_files - - -@click.command( - epilog="""Examples: - -\b - 1. Compares different visualization techniques by showing their results side-by-side - (on a per image basis). The input_folder should be a parent folder containing different - folders for each visualization technique (e.g. input_folder/gradcam, input_folder/scorecam, etc.): - - .. code:: sh - - ptbench compare-vis -i path/to/input_folder -o path/to/output_folder -""", -) -@click.option( - "--input-folder", - "-i", - help="Path to the parent folder of resulted visualizations from visualize command", - required=True, - type=click.Path(), -) -@click.option( - "--output-folder", - "-o", - help="Path where to store the cam comparisons (created if does not exist)", - required=True, - default="cam_comparisons", - type=click.Path(), -) -@verbosity_option(logger=logger, expose_value=False) -def compare_vis(input_folder, output_folder) -> None: - """Compares multiple visualization techniques by showing their results in - one image.""" - - import os - - import cv2 - import matplotlib.pyplot as plt - - assert ( - input_folder != output_folder - ), "Output folder must not be the same as the input folder." - - # Check that the output folder is not a subdirectory of the input_folder - assert not output_folder.startswith( - input_folder - ), "Output folder must not be a subdirectory of the input folder." - - if not os.path.exists(output_folder): - os.makedirs(output_folder) - - dir_element_names = os.listdir(input_folder) - - # Sort folders by visualization type - dir_element_names.sort(key=_sorting_rule) - - # Get the list of subfolders - vis_type_folders = [ - os.path.join(input_folder, d) - for d in dir_element_names - if os.path.isdir(os.path.join(input_folder, d)) - ] - - if not vis_type_folders: - raise ValueError("No subdirectories found in the parent folder.") - - for vis_type_folder in vis_type_folders: - target_class_folders = next(os.walk(vis_type_folder))[1] - for target_class_folder in target_class_folders: - current_target_class_dir = os.path.join( - vis_type_folder, target_class_folder - ) - dataset_type_folders = next(os.walk(current_target_class_dir))[1] - for dataset_type_folder in dataset_type_folders: - comparison_folders = [ - os.path.join( - input_folder, - vis_type, - target_class_folder, - dataset_type_folder, - ) - for vis_type in dir_element_names - ] - - valid_folders = [] - - for folder in comparison_folders: - if not os.path.exists(folder): - logger.warning(f"Folder does not exist: {folder}") - else: - valid_folders.append(folder) - - comparison_folders = valid_folders - - output_directory = os.path.join( - output_folder, target_class_folder, dataset_type_folder - ) - os.makedirs(output_directory, exist_ok=True) - - # Use a set (unordered collection of unique elements) for efficient membership tests - image_names = { - os.path.basename(img) - for img in _get_images_from_directory(comparison_folders[0]) - } - - # Only keep image names that exist in all folders - for folder in comparison_folders[1:]: - # This is basically an intersection-check of contents of different folders - # Images that don't exist in all folders are removed from the set - image_names &= { - os.path.basename(img) - for img in _get_images_from_directory(folder) - } - - if not image_names: - raise ValueError("No common images found in the folders.") - - max_cols = min(5, len(comparison_folders)) - rows, cols = ( - len(comparison_folders) // max_cols, - len(comparison_folders) % max_cols, - ) - rows = rows if cols == 0 else rows + 1 - - for image_name in image_names: - fig, axs = plt.subplots( - rows, max_cols, figsize=(4 * max_cols, 4 * rows) - ) - fig.subplots_adjust( - left=0, - right=1, - bottom=0, - top=1, - wspace=0.05, - hspace=0.1, - ) - axs = axs.ravel() - for i, folder in enumerate(comparison_folders): - image_path = [ - img - for img in _get_images_from_directory(folder) - if os.path.basename(img) == image_name - ][0] - try: - img = cv2.imread(image_path) - img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) - except cv2.error as e: - raise RuntimeError( - "Error reading or processing image: " - + image_path - ) from e - axs[i].imshow(img) - axs[i].set_title( - os.path.basename( - os.path.dirname(os.path.dirname(folder)) - ), - fontsize=20, - ) - axs[i].axis("off") - - # Fill the remaining columns with empty plots (white) - for i in range(len(comparison_folders), rows * max_cols): - axs[i].axis("off") - - plt.savefig( - os.path.join(output_directory, f"{image_name}"), - bbox_inches="tight", - pad_inches=0, - ) - plt.close(fig) diff --git a/src/ptbench/scripts/evaluatevis.py b/src/ptbench/scripts/evaluatevis.py deleted file mode 100644 index ff67b7020fbea38d8bc4dc7a0ec0cc10fa1628b3..0000000000000000000000000000000000000000 --- a/src/ptbench/scripts/evaluatevis.py +++ /dev/null @@ -1,39 +0,0 @@ -# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> -# -# SPDX-License-Identifier: GPL-3.0-or-later - -import click - -from clapper.click import verbosity_option -from clapper.logging import setup - -logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") - - -@click.command( - epilog="""Examples: - -\b - 1. Takes all the .csv files resulting from the visualize command and generates - a summary.csv file containing a summary of the results. - - .. code:: sh - - ptbench evaluatevis -i path/to/input_folder -""", -) -@click.option( - "--input-folder", - "-i", - help="Path to the folder including the .csv files with the individual scores.", - required=True, - type=click.Path(), -) -@verbosity_option(logger=logger, expose_value=False) -def evaluatevis(input_folder) -> None: - """Calculates summary statistics for the scores for one visualization - type.""" - - from ..utils.cam_utils import calculate_metrics_avg_for_every_class - - calculate_metrics_avg_for_every_class(input_folder) diff --git a/src/ptbench/scripts/visualize.py b/src/ptbench/scripts/visualize.py deleted file mode 100644 index 79097c3ac2f4811edee77c8cc85a5f5b1999cd40..0000000000000000000000000000000000000000 --- a/src/ptbench/scripts/visualize.py +++ /dev/null @@ -1,158 +0,0 @@ -# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> -# -# SPDX-License-Identifier: GPL-3.0-or-later - -import os -import pathlib - -import click - -from clapper.click import ConfigCommand, ResourceOption, verbosity_option -from clapper.logging import setup - -logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") - - -@click.command( - entry_point_group="ptbench.config", - cls=ConfigCommand, - epilog="""Examples: - -\b - 1. Generates visualizations in form of heatmaps from existing saliency maps for a dataset configuration: - - .. code:: sh - - ptbench visualize -vv pasa tbx11k-v1-healthy-vs-atb --input-folder=parent_folder/gradcam/ --road-path=parent_folder/gradcam/ --output-folder=path/to/visualizations - -""", -) -@click.option( - "--model", - "-m", - help="A lightining module instance implementing the network to be used for applying the necessary data transformations.", - required=True, - cls=ResourceOption, -) -@click.option( - "--datamodule", - "-d", - help="A lighting data module containing the training, validation and test sets.", - required=True, - cls=ResourceOption, -) -@click.option( - "--input-folder", - "-i", - help="Path to the folder containing the saliency maps for a specific visualization type.", - required=True, - type=click.Path( - file_okay=False, - dir_okay=True, - writable=True, - path_type=pathlib.Path, - ), - default="visualizations", - cls=ResourceOption, -) -@click.option( - "--output-folder", - "-o", - help="Path where to store the ROAD scores (created if does not exist)", - required=True, - type=click.Path( - file_okay=False, - dir_okay=True, - writable=True, - path_type=pathlib.Path, - ), - default="visualizations", - cls=ResourceOption, -) -@click.option( - "--visualize-groundtruth", - "-vgt", - help="If set, visualizations for ground truth labels will be generated. Only works for datasets with bounding boxes.", - is_flag=True, - default=False, - cls=ResourceOption, -) -@click.option( - "--road-path", - "-r", - help="If the path to the previously calculated ROAD scores is provided, MoRF, LeRF, and combined ROAD scores will be visualized on each image.", - required=False, - default=None, - cls=ResourceOption, - type=click.Path(), -) -@click.option( - "--visualize-detected-bbox", - "-vb", - help="If set, largest component bounding boxes will be visualized on each image. Only works if the bounding boxes have been previously generated.", - is_flag=True, - default=False, - cls=ResourceOption, -) -@click.option( - "--threshold", - "-t", - help="Only activations above this threshold will be visualized.", - show_default=True, - required=True, - default=0.0, - type=click.FloatRange(min=0, max=1), - cls=ResourceOption, -) -@verbosity_option(logger=logger, cls=ResourceOption, expose_value=False) -def visualize( - model, - datamodule, - input_folder, - output_folder, - visualize_groundtruth, - road_path, - visualize_detected_bbox, - threshold, - **_, -) -> None: - """Generates heatmaps for input CXRs based on existing saliency maps.""" - - from ..engine.visualizer import run - from .utils import save_sh_command - - assert ( - input_folder != output_folder - ), "Output folder must not be the same as the input folder." - - assert not str(output_folder).startswith( - str(input_folder) - ), "Output folder must not be a subdirectory of the input folder." - - logger.info(f"Output folder: {output_folder}") - os.makedirs(output_folder, exist_ok=True) - - save_sh_command(output_folder / "command.sh") - - datamodule.set_chunk_size(1, 1) - datamodule.drop_incomplete_batch = False - # datamodule.cache_samples = cache_samples - # datamodule.parallel = parallel - datamodule.model_transforms = model.model_transforms - - datamodule.prepare_data() - datamodule.setup(stage="predict") - - dataloaders = datamodule.predict_dataloader() - - for k, v in dataloaders.items(): - run( - data_loader=v, - input_folder=input_folder, - output_folder=output_folder, - dataset_split_name=k, - visualize_groundtruth=visualize_groundtruth, - road_path=road_path, - visualize_detected_bbox=visualize_detected_bbox, - threshold=threshold, - ) diff --git a/src/ptbench/utils/cam_utils.py b/src/ptbench/utils/cam_utils.py deleted file mode 100644 index f379e644d7fb1a73059fcc6597ee77cb633ec444..0000000000000000000000000000000000000000 --- a/src/ptbench/utils/cam_utils.py +++ /dev/null @@ -1,292 +0,0 @@ -# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> -# -# SPDX-License-Identifier: GPL-3.0-or-later - -import glob -import os -import shutil - -import cv2 -import numpy as np -import pandas as pd - - -def _calculate_stats_over_dataset(file_path): - data = pd.read_csv(file_path) - - # Compute mean, median and standard deviation for each score type - metrics = {} - for column in [ - "MoRF", - "LeRF", - "Combined Score ((LeRF-MoRF) / 2)", - "IoU", - "IoDA", - "propEnergy", - "ASF", - ]: - mean = round(data[column].mean(), 3) - median = round(data[column].median(), 3) - std_dev = round(data[column].std(), 3) - q1 = round(data[column].quantile(0.25), 3) - q3 = round(data[column].quantile(0.75), 3) - - metrics[column] = { - "mean": mean, - "median": median, - "std_dev": std_dev, - "Q1": q1, - "Q3": q3, - } - - return metrics - - -def calculate_metrics_avg_for_every_class(input_folder): - directories = [ - dir - for dir in os.listdir(input_folder) - if os.path.isdir(os.path.join(input_folder, dir)) - ] - - # Get unique file names ending with .csv without the extension - csv_file_names = set() - for directory in directories: - dir_path = os.path.join(input_folder, directory) - csv_files_in_dir = glob.glob(os.path.join(dir_path, "*.csv")) - for csv_file in csv_files_in_dir: - csv_file_name = os.path.splitext(os.path.basename(csv_file))[0] - csv_file_names.add(csv_file_name) - - all_results = {file_name: [] for file_name in csv_file_names} - - for file_name in csv_file_names: - previous_summary = os.path.join( - input_folder, f"{file_name}_summary.csv" - ) - if os.path.exists(previous_summary): - backup = previous_summary + "~" - if os.path.exists(backup): - os.unlink(backup) - shutil.move(previous_summary, backup) - - for directory in directories: - dir_path = os.path.join(input_folder, directory) - - for file_name in csv_file_names: - csv_file = os.path.join(dir_path, f"{file_name}.csv") - - if os.path.exists(csv_file): - class_name = directory - metrics = _calculate_stats_over_dataset(csv_file) - - for score_type, values in metrics.items(): - result = { - "Class Name": class_name, - "Score Type": score_type, - "Mean": values["mean"], - "Standard Deviation": values["std_dev"], - "Median": values["median"], - "Q1": values["Q1"], - "Q3": values["Q3"], - } - all_results[file_name].append(result) - - for file_name, results in all_results.items(): - df_results = pd.DataFrame(results) - df_results.to_csv( - os.path.join(input_folder, f"{file_name}_summary.csv"), index=False - ) - - -def draw_boxes_on_image(img_with_boxes, bboxes): - for i, bbox in enumerate(bboxes): - xmin, ymin = int(bbox[1].item()), int(bbox[2].item()) - width, height = int(bbox[3].item()), int(bbox[4].item()) - xmax = int(xmin + width) - ymax = int(ymin + height) - cv2.rectangle( - img_with_boxes, (xmin, ymin), (xmax, ymax), (0, 255, 0), 2 - ) - - # Add text to the bbox - text = f"GTruth bbox {i+1}" - font_scale = max( - 0.5, min(width, height) / 200.0 - ) # Ensure minimum font scale of 0.5 - font_scale = min(font_scale, 1.3) # Ensure maximum font scale of 1.3 - - # Computes text height - (text_width, text_height), baseline = cv2.getTextSize( - text, cv2.FONT_HERSHEY_SIMPLEX, font_scale, 2 - ) - - # Make sure not to draw text outside of image - if ymin + height + text_height + 2 < img_with_boxes.shape[0]: - y = ymin + height + text_height + 2 # put below box - else: - y = ymin - text_height - 2 # put above box - - y = int(y) - - cv2.putText( - img_with_boxes, - text, - (xmin + 5, y), - cv2.FONT_HERSHEY_SIMPLEX, - font_scale, - (36, 255, 12), - 2, - ) - return img_with_boxes - - -def draw_largest_component_bbox_on_image(img_with_boxes, x, y, w, h): - cv2.rectangle(img_with_boxes, (x, y), (x + w, y + h), (0, 0, 255), 2) - - # Add text to the bbox - text = "Detected Area" - font_scale = max(0.5, min(w, h) / 200.0) # Ensure minimum font scale of 0.5 - font_scale = min(font_scale, 1.3) # Ensure maximum font scale of 1.3 - - # Computes text height - (text_width, text_height), baseline = cv2.getTextSize( - text, cv2.FONT_HERSHEY_SIMPLEX, font_scale, 2 - ) - - # Make sure not to draw text outside bottom of image - if y + h + text_height + 2 < img_with_boxes.shape[0]: - y = y + h + text_height + 2 # put below box - else: - y = y - text_height - 2 # put above box - - y = int(y) - - cv2.putText( - img_with_boxes, - text, - (x + 5, y), - cv2.FONT_HERSHEY_SIMPLEX, - font_scale, - (0, 0, 255), - 2, - ) - return img_with_boxes - - -def show_cam_on_image( - img: np.ndarray, - thresholded_mask: np.ndarray, - use_rgb: bool = False, - colormap: int = cv2.COLORMAP_JET, - image_weight: float = 0.5, -) -> np.ndarray: - """This function overlays the cam mask on the image as an heatmap. By - default the heatmap is in BGR format. - - :param img: The base image in RGB or BGR format. - :param mask: The cam mask. - :param threshold: Only apply heatmap to areas where cam is above this threshold. - :param use_rgb: Whether to use an RGB or BGR heatmap, this should be set to True if 'img' is in RGB format. - :param colormap: The OpenCV colormap to be used. - :param image_weight: The final result is image_weight * img + (1-image_weight) * mask. - :returns: The default image with the cam overlay. - - This is a slightly modified version of the show_cam_on_image implementation in: - https://github.com/jacobgil/pytorch-grad-cam - """ - - if img.shape[:2] != thresholded_mask.shape: - raise ValueError( - "The shape of the mask should be the same as the shape of the image." - ) - - heatmap = cv2.applyColorMap(np.uint8(255 * thresholded_mask), colormap) - if use_rgb: - heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB) - heatmap = np.float32(heatmap) / 255 - - if np.max(img) > 1: - raise Exception("The input image should np.float32 in the range [0, 1]") - - if image_weight < 0 or image_weight > 1: - raise Exception( - f"image_weight should be in the range [0, 1].\ - Got: {image_weight}" - ) - - # For pixels where the mask is zero, - # the original image pixels are being used without a mask. - cam = np.where( - thresholded_mask[..., np.newaxis] == 0, - img, - (1 - image_weight) * heatmap + image_weight * img, - ) - cam = cam / np.max(cam) - return np.uint8(255 * cam) - - -def visualize_road_scores( - visualization, MoRF_score, LeRF_score, combined_score, name, percentiles -): - visualization = cv2.putText( - visualization, - name, - (10, 20), - cv2.FONT_HERSHEY_SIMPLEX, - 0.4, - (255, 255, 255), - 1, - cv2.LINE_AA, - ) - visualization = cv2.putText( - visualization, - f"Percentiles: {percentiles}", - (10, 40), - cv2.FONT_HERSHEY_SIMPLEX, - 0.4, - (255, 255, 255), - 1, - cv2.LINE_AA, - ) - visualization = cv2.putText( - visualization, - "Remove and Debias", - (10, 55), - cv2.FONT_HERSHEY_SIMPLEX, - 0.4, - (255, 255, 255), - 1, - cv2.LINE_AA, - ) - visualization = cv2.putText( - visualization, - f"MoRF score: {MoRF_score:.5f}", - (10, 70), - cv2.FONT_HERSHEY_SIMPLEX, - 0.4, - (255, 255, 255), - 1, - cv2.LINE_AA, - ) - visualization = cv2.putText( - visualization, - f"LeRF score: {LeRF_score:.5f}", - (10, 85), - cv2.FONT_HERSHEY_SIMPLEX, - 0.4, - (255, 255, 255), - 1, - cv2.LINE_AA, - ) - visualization = cv2.putText( - visualization, - f"(LeRF-MoRF)/2: {combined_score:.5f}", - (10, 100), - cv2.FONT_HERSHEY_SIMPLEX, - 0.4, - (255, 255, 255), - 1, - cv2.LINE_AA, - ) - return visualization