From c3bebb065806751adeeac2fc60843786aa453d82 Mon Sep 17 00:00:00 2001 From: "ogueler@idiap.ch" <ogueler@vws110.idiap.ch> Date: Wed, 20 Sep 2023 01:01:39 +0200 Subject: [PATCH] added saliency map evaluation and visualisation for lightning models --- src/ptbench/engine/saliencymap_evaluator.py | 41 +++++---- src/ptbench/engine/visualizer.py | 70 ++++++++------- src/ptbench/scripts/calculate_road.py | 12 ++- src/ptbench/scripts/evaluate_saliencymaps.py | 64 +++++++------- src/ptbench/scripts/generate_saliencymaps.py | 10 +-- src/ptbench/scripts/visualize.py | 89 +++++++++++--------- src/ptbench/utils/cam_utils.py | 7 +- 7 files changed, 150 insertions(+), 143 deletions(-) diff --git a/src/ptbench/engine/saliencymap_evaluator.py b/src/ptbench/engine/saliencymap_evaluator.py index 1285dc95..07d9dde8 100644 --- a/src/ptbench/engine/saliencymap_evaluator.py +++ b/src/ptbench/engine/saliencymap_evaluator.py @@ -2,7 +2,6 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -import ast import logging import os @@ -29,10 +28,8 @@ def _compute_max_iou_and_ioda(detected_box, gt_boxes): max_gt_area = 0 for bbox in gt_boxes: - unpacked_bbox = bbox[0] - bbox_dict = ast.literal_eval(unpacked_bbox.replace("'", '"')) - xmin, ymin = int(bbox_dict["xmin"]), int(bbox_dict["ymin"]) - width, height = int(bbox_dict["width"]), int(bbox_dict["height"]) + xmin, ymin = int(bbox[1].item()), int(bbox[2].item()) + width, height = int(bbox[3].item()), int(bbox[4].item()) gt_area = width * height @@ -78,10 +75,8 @@ def _compute_simultaneous_iou_and_ioda(detected_box, gt_boxes): total_gt_area = 0 for bbox in gt_boxes: - unpacked_bbox = bbox[0] - bbox_dict = ast.literal_eval(unpacked_bbox.replace("'", '"')) - xmin, ymin = int(bbox_dict["xmin"]), int(bbox_dict["ymin"]) - width, height = int(bbox_dict["width"]), int(bbox_dict["height"]) + xmin, ymin = int(bbox[1].item()), int(bbox[2].item()) + width, height = int(bbox[3].item()), int(bbox[4].item()) gt_area = width * height total_gt_area += gt_area @@ -112,10 +107,8 @@ def _compute_avg_saliency_focus(gt_boxes, saliency_map): # For each gt box, draw a binary mask # The binary_mask will be 1 where the gt boxes are located for bbox in gt_boxes: - unpacked_bbox = bbox[0] - bbox_dict = ast.literal_eval(unpacked_bbox.replace("'", '"')) - xmin, ymin = int(bbox_dict["xmin"]), int(bbox_dict["ymin"]) - width, height = int(bbox_dict["width"]), int(bbox_dict["height"]) + xmin, ymin = int(bbox[1].item()), int(bbox[2].item()) + width, height = int(bbox[3].item()), int(bbox[4].item()) binary_mask[ymin : ymin + height, xmin : xmin + width] = 1 @@ -143,10 +136,8 @@ def _compute_proportional_energy(gt_boxes, saliency_map): # For each gt box, draw a binary mask # The binary_mask will be 1 where the gt boxes are located for bbox in gt_boxes: - unpacked_bbox = bbox[0] - bbox_dict = ast.literal_eval(unpacked_bbox.replace("'", '"')) - xmin, ymin = int(bbox_dict["xmin"]), int(bbox_dict["ymin"]) - width, height = int(bbox_dict["width"]), int(bbox_dict["height"]) + xmin, ymin = int(bbox[1].item()), int(bbox[2].item()) + width, height = int(bbox[3].item()), int(bbox[4].item()) binary_mask[ymin : ymin + height, xmin : xmin + width] = 1 @@ -223,7 +214,7 @@ def process_target_class( # Write metrics to csv file csv_writer.writerow( [ - names[0].split("/")[1], + names[0], iou, ioda, proportional_energy, @@ -268,19 +259,25 @@ def run( """ for samples in tqdm(data_loader, desc="batches", leave=False, disable=None): # Check if the sample has a bounding box entry - if len(samples) < 4: + if "radsign_bboxes" not in samples[1]: logger.warning( "The dataset does not contain bounding box information. No localization metrics can be calculated." ) return else: # TB negative labels are skipped (they don't have gt bboxes) - if samples[3][0] == "none": + if samples[1]["label"].item() == 0: continue - names = samples[0] + names = samples[1]["name"] - gt_bboxes = samples[3] + gt_bboxes = samples[1]["radsign_bboxes"] + + if not gt_bboxes: + logger.warning( + f'This sample does not have bounding box information. No localization metrics can be calculated. Sample "{names[0]}" is skipped.' + ) + continue for target_class_name, csv_writer in csv_writers.items(): saliency_map_path = os.path.join( diff --git a/src/ptbench/engine/visualizer.py b/src/ptbench/engine/visualizer.py index b734a547..b149f3ea 100644 --- a/src/ptbench/engine/visualizer.py +++ b/src/ptbench/engine/visualizer.py @@ -78,8 +78,8 @@ def run( ), "Output folder must not be the same as the input folder." # Check that the output folder is not a subdirectory of the input_folder - assert not output_folder.startswith( - input_folder + assert not str(output_folder).startswith( + str(input_folder) ), "Output folder must not be a subdirectory of the input folder." output_folder = os.path.abspath(output_folder) @@ -89,6 +89,8 @@ def run( target_classes = _get_target_classes_from_directory(input_folder) + shown_warnings = set() + if visualize_detected_bbox: # Load detected bounding box information localization_metric_dfs = {} @@ -113,16 +115,19 @@ def run( for samples in tqdm(data_loader, desc="batches", leave=False, disable=None): includes_bboxes = True - if len(samples) < 4 or samples[2][0].item() == 0: + if samples[1]["label"].item() == 0 or "radsign_bboxes" not in samples[1]: includes_bboxes = False + else: + gt_bboxes = samples[1]["radsign_bboxes"] + if not gt_bboxes: + includes_bboxes = False - # if visualize_groundtruth and not includes_bboxes: - # logger.warning( - # "This sample does not have bounding box information. No ground truth bounding boxes can be visualized." - # ) + names = samples[1]["name"] - names = samples[0] - gt_bboxes = samples[3] + if visualize_groundtruth and not includes_bboxes and samples[1]["label"].item() == 1: + logger.warning( + f'Sample "{names[0]}" does not have bounding box information. No ground truth bounding boxes can be visualized.' + ) for target_class in target_classes: input_base_path = ( @@ -136,7 +141,7 @@ def run( saliency_map_path = os.path.join(input_base_path, saliency_map_name) # Skip if the saliency map does not exist (e.g. if no saliency maps were generated for TB negative samples) - if not includes_bboxes and not os.path.exists(saliency_map_path): + if not os.path.exists(saliency_map_path): continue saliency_map = np.load(saliency_map_path) @@ -146,11 +151,10 @@ def run( img_with_boxes = original_img.copy() # Draw bounding boxes on the image - if includes_bboxes: - if gt_bboxes != "none" and visualize_groundtruth: - img_with_boxes = draw_boxes_on_image( - img_with_boxes, gt_bboxes - ) + if visualize_groundtruth and includes_bboxes: + img_with_boxes = draw_boxes_on_image( + img_with_boxes, gt_bboxes + ) # show_cam_on_image expects the img to be between [0, 1] rgb_img = img_with_boxes / 255.0 @@ -168,7 +172,7 @@ def run( if visualize_detected_bbox: if target_class in localization_metric_dfs: temp_df = localization_metric_dfs[target_class] - temp_filename = names[0].split("/")[1] + temp_filename = names[0] row = temp_df[temp_df["Image"] == temp_filename] @@ -186,15 +190,16 @@ def run( f"Could not find entry for {temp_filename} in .csv file." ) else: - logger.warning( - f"No localization metrics .csv file found for the {dataset_split_name} split for {target_class}." - ) + error_message = f"No localization metrics .csv file found for the {dataset_split_name} split for {target_class}." + if error_message not in shown_warnings: + logger.warning(error_message) + shown_warnings.add(error_message) # Visualize ROAD scores if road_path is not None: if target_class in road_metric_dfs: temp_df = road_metric_dfs[target_class] - temp_filename = names[0].split("/")[1] + temp_filename = names[0] row = temp_df[temp_df["Image"] == temp_filename] @@ -205,23 +210,24 @@ def run( row["Combined Score ((LeRF-MoRF) / 2)"].values[0] ) percentiles = str(row["Percentiles"].values[0]) + + visualization = visualize_road_scores( + visualization, + MoRF_score, + LeRF_score, + combined_score, + vis_directory_name, + percentiles, + ) else: logger.warning( f"Could not find entry for {temp_filename} in .csv file." ) else: - logger.warning( - f"No ROAD metrics .csv file found for the {dataset_split_name} split for {target_class}." - ) - - visualization = visualize_road_scores( - visualization, - MoRF_score, - LeRF_score, - combined_score, - vis_directory_name, - percentiles, - ) + error_message = f"No ROAD metrics .csv file found for the {dataset_split_name} split for {target_class}." + if error_message not in shown_warnings: + logger.warning(error_message) + shown_warnings.add(error_message) # Save image output_file_path = os.path.join(output_base_path, names[0]) diff --git a/src/ptbench/scripts/calculate_road.py b/src/ptbench/scripts/calculate_road.py index 41bbe3db..708e7254 100644 --- a/src/ptbench/scripts/calculate_road.py +++ b/src/ptbench/scripts/calculate_road.py @@ -194,28 +194,28 @@ def prepare_csv_writers( .. code:: sh - ptbench calculate-road -vv pasa tbx11k_simplified_bbox --device="cuda" --weight=path/to/model_final.pth --output-folder=path/to/output_folder + ptbench calculate-road -vv pasa tbx11k-v1-healthy-vs-atb --device="cuda" --weight=path/to/model_final.pth --output-folder=path/to/output_folder """, ) @click.option( "--model", "-m", - help="A lightining module instance implementing the network to be trained.", + help="A lightining module instance implementing the network to be used for inference.", required=True, cls=ResourceOption, ) @click.option( "--datamodule", "-d", - help="A lighting data module containing the training and validation sets.", + help="A lighting data module containing the training, validation and test sets.", required=True, cls=ResourceOption, ) @click.option( "--output-folder", "-o", - help="Path where to store the visualizations (created if does not exist)", + help="Path where to store the ROAD scores (created if does not exist)", required=True, type=click.Path( file_okay=False, @@ -352,11 +352,9 @@ def calculate_road( logger.info(f"Calculating ROAD scores for '{k}' set...") - data_loader = v - run( model, - data_loader, + data_loader=v, output_folder=output_folder, device=device, cam=cam, diff --git a/src/ptbench/scripts/evaluate_saliencymaps.py b/src/ptbench/scripts/evaluate_saliencymaps.py index 991e9e5e..e9cc72a9 100644 --- a/src/ptbench/scripts/evaluate_saliencymaps.py +++ b/src/ptbench/scripts/evaluate_saliencymaps.py @@ -4,6 +4,7 @@ import csv import os +import pathlib import click @@ -65,18 +66,21 @@ def prepare_csv_writers(input_folder, dataset_split): .. code:: sh - ptbench evaluate-saliencymaps -vv tbx11k_simplified_bbox_rgb --input-folder=parent_folder/gradcam/ + ptbench evaluate-saliencymaps -vv pasa tbx11k-v1-healthy-vs-atb --input-folder=parent_folder/gradcam/ """, ) @click.option( - "--dataset", + "--model", + "-m", + help="A lightining module instance implementing the network to be used for applying the necessary data transformations.", + required=True, + cls=ResourceOption, +) +@click.option( + "--datamodule", "-d", - help="A torch.utils.data.dataset.Dataset instance implementing a dataset " - "to be used for generating visualizations, possibly including all pre-processing " - "pipelines required or, optionally, a dictionary mapping string keys to " - "torch.utils.data.dataset.Dataset instances. All keys that do not start " - "with an underscore (_) will be processed.", + help="A lighting data module containing the training, validation and test sets.", required=True, cls=ResourceOption, ) @@ -85,13 +89,19 @@ def prepare_csv_writers(input_folder, dataset_split): "-i", help="Path to the folder containing the saliency maps for a specific visualization type.", required=True, + type=click.Path( + file_okay=False, + dir_okay=True, + writable=True, + path_type=pathlib.Path, + ), default="visualizations", cls=ResourceOption, - type=click.Path(), ) @verbosity_option(logger=logger, cls=ResourceOption, expose_value=False) def evaluate_saliencymaps( - dataset, + model, + datamodule, input_folder, **_, ) -> None: @@ -102,40 +112,30 @@ def evaluate_saliencymaps( Calculates them for each target class and split of the dataset. """ - import torch + from .utils import save_sh_command + from ..engine.saliencymap_evaluator import run - from torch.utils.data import DataLoader + save_sh_command(input_folder / "command.sh") - from ..engine.saliencymap_evaluator import run + datamodule.set_chunk_size(1, 1) + datamodule.drop_incomplete_batch = False + #datamodule.cache_samples = cache_samples + #datamodule.parallel = parallel + datamodule.model_transforms = model.model_transforms - if "datadir" in dataset: - dataset = ( - dataset["dataset"] - if isinstance(dataset["dataset"], dict) - else dict(test=dataset["dataset"]) - ) - else: - dataset = dataset if isinstance(dataset, dict) else dict(test=dataset) + datamodule.prepare_data() + datamodule.setup(stage="predict") - for k, v in dataset.items(): - if k.startswith("_"): - logger.info(f"Skipping dataset '{k}' (not to be evaluated)") - continue + dataloaders = datamodule.predict_dataloader() + for k, v in dataloaders.items(): csv_files, csv_writers = prepare_csv_writers(input_folder, k) logger.info(f"Calculating localization metrics for '{k}' set...") - data_loader = DataLoader( - dataset=v, - batch_size=1, - shuffle=False, - pin_memory=torch.cuda.is_available(), - ) - run( input_folder, - data_loader, + data_loader=v, dataset_split_name=k, csv_writers=csv_writers, ) diff --git a/src/ptbench/scripts/generate_saliencymaps.py b/src/ptbench/scripts/generate_saliencymaps.py index 331e31db..411ac0d7 100644 --- a/src/ptbench/scripts/generate_saliencymaps.py +++ b/src/ptbench/scripts/generate_saliencymaps.py @@ -126,21 +126,21 @@ def create_cam(vis_type, model, target_layers, use_cuda): .. code:: sh - ptbench generate-saliencymaps -vv densenet tbx11k_simplified_bbox_rgb --device="cuda" --weight=path/to/model_final.pth --output-folder=path/to/visualizations + ptbench generate-saliencymaps -vv densenet tbx11k-v1-healthy-vs-atb --device="cuda" --weight=path/to/model_final.pth --output-folder=path/to/visualizations """, ) @click.option( "--model", "-m", - help="A lightining module instance implementing the network to be trained.", + help="A lightining module instance implementing the network to be used for inference.", required=True, cls=ResourceOption, ) @click.option( "--datamodule", "-d", - help="A lighting data module containing the training and validation sets.", + help="A lighting data module containing the training, validation and test sets.", required=True, cls=ResourceOption, ) @@ -273,11 +273,9 @@ def generate_saliencymaps( for k, v in dataloaders.items(): logger.info(f"Generating saliency maps for '{k}' set...") - data_loader = v - run( model, - data_loader, + data_loader=v, output_folder=output_folder, dataset_split_name=k, device=device, diff --git a/src/ptbench/scripts/visualize.py b/src/ptbench/scripts/visualize.py index b3922581..96bf4b04 100644 --- a/src/ptbench/scripts/visualize.py +++ b/src/ptbench/scripts/visualize.py @@ -2,6 +2,8 @@ # # SPDX-License-Identifier: GPL-3.0-or-later +import pathlib + import click from clapper.click import ConfigCommand, ResourceOption, verbosity_option @@ -20,18 +22,21 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") .. code:: sh - ptbench visualize -vv tbx11k_simplified_bbox_rgb --input-folder=parent_folder/gradcam/ --road-path=parent_folder/gradcam --output-folder=path/to/visualizations + ptbench visualize -vv pasa tbx11k-v1-healthy-vs-atb --input-folder=parent_folder/gradcam/ --road-path=parent_folder/gradcam/ --output-folder=path/to/visualizations """, ) @click.option( - "--dataset", + "--model", + "-m", + help="A lightining module instance implementing the network to be used for applying the necessary data transformations.", + required=True, + cls=ResourceOption, +) +@click.option( + "--datamodule", "-d", - help="A torch.utils.data.dataset.Dataset instance implementing a dataset " - "to be used for generating visualizations, possibly including all pre-processing " - "pipelines required or, optionally, a dictionary mapping string keys to " - "torch.utils.data.dataset.Dataset instances. All keys that do not start " - "with an underscore (_) will be processed.", + help="A lighting data module containing the training, validation and test sets.", required=True, cls=ResourceOption, ) @@ -40,18 +45,28 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") "-i", help="Path to the folder containing the saliency maps for a specific visualization type.", required=True, + type=click.Path( + file_okay=False, + dir_okay=True, + writable=True, + path_type=pathlib.Path, + ), default="visualizations", cls=ResourceOption, - type=click.Path(), ) @click.option( "--output-folder", "-o", - help="Path where to store the visualizations (created if does not exist)", + help="Path where to store the ROAD scores (created if does not exist)", required=True, + type=click.Path( + file_okay=False, + dir_okay=True, + writable=True, + path_type=pathlib.Path, + ), default="visualizations", cls=ResourceOption, - type=click.Path(), ) @click.option( "--visualize-groundtruth", @@ -64,7 +79,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") @click.option( "--road-path", "-r", - help="If the path to the previously calculated road scores is provided, MoRF, LeRF, and combined ROAD scores will be visualized on each image.", + help="If the path to the previously calculated ROAD scores is provided, MoRF, LeRF, and combined ROAD scores will be visualized on each image.", required=False, default=None, cls=ResourceOption, @@ -90,7 +105,8 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") ) @verbosity_option(logger=logger, cls=ResourceOption, expose_value=False) def visualize( - dataset, + model, + datamodule, input_folder, output_folder, visualize_groundtruth, @@ -101,34 +117,29 @@ def visualize( ) -> None: """Generates heatmaps for input CXRs based on existing saliency maps.""" - import torch + from .utils import save_sh_command + from ..engine.visualizer import run - from torch.utils.data import DataLoader + save_sh_command(input_folder / "command.sh") - from ..engine.visualizer import run + datamodule.set_chunk_size(1, 1) + datamodule.drop_incomplete_batch = False + #datamodule.cache_samples = cache_samples + #datamodule.parallel = parallel + datamodule.model_transforms = model.model_transforms - if "datadir" in dataset: - img_dir = dataset["datadir"] - dataset = ( - dataset["dataset"] - if isinstance(dataset["dataset"], dict) - else dict(test=dataset["dataset"]) - ) - - for k, v in dataset.items(): - if k.startswith("_"): - logger.info(f"Skipping dataset '{k}' (not to be evaluated)") - continue - - data_loader = DataLoader( - dataset=v, - batch_size=1, - shuffle=False, - pin_memory=torch.cuda.is_available(), - ) + datamodule.prepare_data() + datamodule.setup(stage="predict") + + dataloaders = datamodule.predict_dataloader() + + for k, v in dataloaders.items(): + # Hacky way to get img_dir + img_dir = datamodule.splits[k][0][1].datadir + if img_dir: run( - data_loader, + data_loader=v, img_dir=img_dir, input_folder=input_folder, output_folder=output_folder, @@ -138,7 +149,7 @@ def visualize( visualize_detected_bbox=visualize_detected_bbox, threshold=threshold, ) - else: - logger.warning( - 'No "datadir" or "dataset_name" key found in dataset. No visualizations can be generated.' - ) + else: + logger.warning( + 'No "datadir" found in dataset. No visualizations can be generated.' + ) \ No newline at end of file diff --git a/src/ptbench/utils/cam_utils.py b/src/ptbench/utils/cam_utils.py index 4472f451..15699342 100644 --- a/src/ptbench/utils/cam_utils.py +++ b/src/ptbench/utils/cam_utils.py @@ -1,4 +1,3 @@ -import ast import glob import os import shutil @@ -98,10 +97,8 @@ def calculate_metrics_avg_for_every_class(input_folder): def draw_boxes_on_image(img_with_boxes, bboxes): for i, bbox in enumerate(bboxes): - unpacked_bbox = bbox[0] - bbox_dict = ast.literal_eval(unpacked_bbox.replace("'", '"')) - xmin, ymin = int(bbox_dict["xmin"]), int(bbox_dict["ymin"]) - width, height = bbox_dict["width"], bbox_dict["height"] + xmin, ymin = int(bbox[1].item()), int(bbox[2].item()) + width, height = int(bbox[3].item()), int(bbox[4].item()) xmax = int(xmin + width) ymax = int(ymin + height) cv2.rectangle( -- GitLab