diff --git a/src/ptbench/engine/road_calculator.py b/src/ptbench/engine/road_calculator.py new file mode 100644 index 0000000000000000000000000000000000000000..44237dadd7b0020fa15d3e0779f4ce260943e97f --- /dev/null +++ b/src/ptbench/engine/road_calculator.py @@ -0,0 +1,253 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import logging +import os + +import numpy as np +import torch + +from pytorch_grad_cam.metrics.road import ( + ROADCombined, + ROADLeastRelevantFirstAverage, + ROADMostRelevantFirstAverage, +) +from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget +from tqdm import tqdm + +logger = logging.getLogger(__name__) + + +class SigmoidClassifierOutputTarget: + def __init__(self, category): + self.category = category + + def __call__(self, model_output): + sigmoid_output = torch.sigmoid(model_output) + if len(sigmoid_output.shape) == 1: + return sigmoid_output[self.category] + return sigmoid_output[:, self.category] + + +rs_maps = { + 0: "cardiomegaly", + 1: "emphysema", + 2: "effusion", + 3: "hernia", + 4: "infiltration", + 5: "mass", + 6: "nodule", + 7: "atelectasis", + 8: "pneumothorax", + 9: "pleural thickening", + 10: "pneumonia", + 11: "fibrosis", + 12: "edema", + 13: "consolidation", +} + + +def calculate_road_metrics( + input_image, + grayscale_cams, + model, + targets=None, + percentiles=[20, 40, 60, 80], +): + """Calculates ROAD scores by averaging the scores for different percentiles + for a single input image for a given visualization method and a given + target class.""" + cam_metric_ROADMoRF_avg = ROADMostRelevantFirstAverage( + percentiles=percentiles + ) + cam_metric_ROADLeRF_avg = ROADLeastRelevantFirstAverage( + percentiles=percentiles + ) + cam_metric_ROADCombined_avg = ROADCombined(percentiles=percentiles) + + # Calculate ROAD scores for each percentile + MoRF_scores = cam_metric_ROADMoRF_avg( + input_tensor=input_image, + cams=grayscale_cams, + model=model, + targets=targets, + ) + LeRF_scores = cam_metric_ROADLeRF_avg( + input_tensor=input_image, + cams=grayscale_cams, + model=model, + targets=targets, + ) + combined_scores = cam_metric_ROADCombined_avg( + input_tensor=input_image, + cams=grayscale_cams, + model=model, + targets=targets, + ) + + return MoRF_scores, LeRF_scores, combined_scores + + +# Helper function to calculate the ROAD scores for a single target class +# of a single input image. +def process_target_class( + model, + names, + images, + targets, + metric_targets, + cam, + csv_writer, + percentiles, +): + grayscale_cams = cam(input_tensor=images, targets=targets) + + MoRF_scores, LeRF_scores, combined_scores = calculate_road_metrics( + input_image=images, + grayscale_cams=grayscale_cams, + model=model, + targets=metric_targets, + percentiles=percentiles, + ) + + MoRF_score = MoRF_scores[0] + LeRF_score = LeRF_scores[0] + combined_score = combined_scores[0] + + # Write metrics to csv file + csv_writer.writerow( + [ + names[0].split("/")[1], + MoRF_score, + LeRF_score, + combined_score, + str(percentiles), + ] + ) + + +def run( + model, + data_loader, + output_folder, + device, + cam, + csv_writers, + target_class="highest", + tb_positive_only=True, +): + """Applies visualization techniques on input CXR, and perturbs them to + calculate ROAD scores. + + Parameters + --------- + model : py:class:`torch.nn.Module` + The model to use for visualization. + + data_loader : py:class:`torch.torch.utils.data.DataLoader` + The pytorch Dataloader used to iterate over batches. + + output_folder : str + Directory in which the results will be saved. + + dataset_split_name : str + Name of the dataset split (e.g. "train", "validation", "test"). + + device : str + A string indicating the device to use (e.g. "cpu" or "cuda"). The device can also be specified (cuda:0) + + cam : py:class: `pytorch_grad_cam.GradCAM`, `pytorch_grad_cam.ScoreCAM`, + `pytorch_grad_cam.FullGrad`, `pytorch_grad_cam.RandomCAM`, + `pytorch_grad_cam.EigenCAM`, `pytorch_grad_cam.EigenGradCAM`, + `pytorch_grad_cam.LayerCAM`, `pytorch_grad_cam.XGradCAM`, + `pytorch_grad_cam.AblationCAM`, `pytorch_grad_cam.HiResCAM`, + `pytorch_grad_cam.GradCAMElementWise`, `pytorch_grad_cam.GradCAMplusplus`, + The CAM object to use for visualization. + + visualization_types : list + Type of visualization techniques to be applied. Possible values are: + "GradCAM", "ScoreCAM", "FullGrad", "RandomCAM", "HiResCAM", "GradCAMElementWise", "GradCAMPlusPlus", "XGradCAM", "AblationCAM", + "EigenCAM", "EigenGradCAM", "LayerCAM". + + csv_writers : dict + Dictionary containing csv writer objects for each target class. + + target_class : str + (Use only with multi-label models) Which class to target for CAM calculation. Can be either set to "all" or "highest". "highest" is default, which means only visualizations for the class with the highest activation will be generated. + + tb_positive_only : bool + If set, only TB positive samples will be visualized. + + Returns + ------- + + all_road_scores : list + All the ROAD scores associated with filename, saved as .csv. + """ + output_folder = os.path.abspath(output_folder) + + logger.info(f"Output folder: {output_folder}") + os.makedirs(output_folder, exist_ok=True) + + model_name = model.__class__.__name__ + + percentiles = [20, 40, 60, 80] + + for samples in tqdm(data_loader, desc="batches", leave=False, disable=None): + # TB negative labels are skipped + if samples[2][0].item() == 0: + if tb_positive_only: + continue + + names = samples[0] + images = samples[1].to( + device=device, non_blocking=torch.cuda.is_available() + ) + + if model_name == "DensenetRS" and target_class.lower() == "all": + for target in range(14): + targets = [ClassifierOutputTarget(target)] + metric_targets = [SigmoidClassifierOutputTarget(target)] + + csv_writer = csv_writers[rs_maps[target]] + + process_target_class( + model, + names, + images, + targets, + metric_targets, + cam, + csv_writer, + percentiles, + ) + + if model_name == "DensenetRS": + # Get the class with highest activation manually + outputs = cam.activations_and_grads(images) + target_categories = np.argmax(outputs.cpu().data.numpy(), axis=-1) + targets = [ + ClassifierOutputTarget(category) + for category in target_categories + ] + metric_targets = [ + SigmoidClassifierOutputTarget(category) + for category in target_categories + ] + else: + targets = [ClassifierOutputTarget(0)] + metric_targets = [SigmoidClassifierOutputTarget(0)] + + csv_writer = csv_writers["targeted_class"] + + process_target_class( + model, + names, + images, + targets, + metric_targets, + cam, + csv_writer, + percentiles, + ) diff --git a/src/ptbench/engine/saliencymap_evaluator.py b/src/ptbench/engine/saliencymap_evaluator.py new file mode 100644 index 0000000000000000000000000000000000000000..1285dc950a20cf521a120b7f7de48e3ad8d8006a --- /dev/null +++ b/src/ptbench/engine/saliencymap_evaluator.py @@ -0,0 +1,298 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import ast +import logging +import os + +import cv2 +import numpy as np + +from tqdm import tqdm + +logger = logging.getLogger(__name__) + + +def _compute_max_iou_and_ioda(detected_box, gt_boxes): + """Will calculate how much of detected area lies in ground truth boxes. + + If there are multiple gt boxes, the detected area will be calculated + for each gt box separately and the gt box with the highest + intersecting part will be used for the calculation. + """ + x_L, y_L, w_L, h_L = detected_box + detected_area = w_L * h_L + if detected_area == 0: + return 0, 0 + max_intersection = 0 + max_gt_area = 0 + + for bbox in gt_boxes: + unpacked_bbox = bbox[0] + bbox_dict = ast.literal_eval(unpacked_bbox.replace("'", '"')) + xmin, ymin = int(bbox_dict["xmin"]), int(bbox_dict["ymin"]) + width, height = int(bbox_dict["width"]), int(bbox_dict["height"]) + + gt_area = width * height + + # Calculate intersection + xi = max(x_L, xmin) + yi = max(y_L, ymin) + wi = max(0, min(x_L + w_L, xmin + width) - xi) + hi = max(0, min(y_L + h_L, ymin + height) - yi) + + intersection = 0 + if wi > 0 and hi > 0: + intersection = wi * hi + + if intersection > max_intersection: + max_intersection = intersection + max_gt_area = gt_area + + if max_gt_area == 0 and max_intersection == 0: + # This case means no intersection was found, even though there are gt boxes + iou, ioda = 0, 0 + else: + iou = max_intersection / ( + detected_area + max_gt_area - max_intersection + ) + ioda = max_intersection / detected_area + + return iou, ioda + + +def _compute_simultaneous_iou_and_ioda(detected_box, gt_boxes): + """Will calculate how much of detected area lies between ground truth + boxes. + + This means that if there are multiple gt boxes, the detected area + will be compared to them simultaneously (and not to each gt box + separately).) + """ + x_L, y_L, w_L, h_L = detected_box + detected_area = w_L * h_L + if detected_area == 0: + return 0, 0 + intersection = 0 + total_gt_area = 0 + + for bbox in gt_boxes: + unpacked_bbox = bbox[0] + bbox_dict = ast.literal_eval(unpacked_bbox.replace("'", '"')) + xmin, ymin = int(bbox_dict["xmin"]), int(bbox_dict["ymin"]) + width, height = int(bbox_dict["width"]), int(bbox_dict["height"]) + + gt_area = width * height + total_gt_area += gt_area + + # Calculate intersection + xi = max(x_L, xmin) + yi = max(y_L, ymin) + wi = max(0, min(x_L + w_L, xmin + width) - xi) + hi = max(0, min(y_L + h_L, ymin + height) - yi) + + if wi > 0 and hi > 0: + intersection += wi * hi + + iou = intersection / (detected_area + total_gt_area - intersection) + ioda = intersection / detected_area + + return iou, ioda + + +def _compute_avg_saliency_focus(gt_boxes, saliency_map): + """Will calculate how much of the ground truth bounding boxes area is + covered by the activations.""" + + binary_mask = np.zeros_like(saliency_map) + + total_gt_bbox_area = 0 + + # For each gt box, draw a binary mask + # The binary_mask will be 1 where the gt boxes are located + for bbox in gt_boxes: + unpacked_bbox = bbox[0] + bbox_dict = ast.literal_eval(unpacked_bbox.replace("'", '"')) + xmin, ymin = int(bbox_dict["xmin"]), int(bbox_dict["ymin"]) + width, height = int(bbox_dict["width"]), int(bbox_dict["height"]) + + binary_mask[ymin : ymin + height, xmin : xmin + width] = 1 + + total_gt_bbox_area += width * height + + multiplied_mask = binary_mask * saliency_map + numerator = np.sum(multiplied_mask) + + if total_gt_bbox_area == 0: + avg_saliency_focus = 0 + else: + avg_saliency_focus = numerator / total_gt_bbox_area + + return avg_saliency_focus + + +# Own implementation based on +# "Score-CAM: Score-Weighted Visual Explanations for Convolutional Neural Networks" by Wang et al. (2020), +# https://arxiv.org/abs/1910.01279 +def _compute_proportional_energy(gt_boxes, saliency_map): + """Will calculate how much activation lies within the ground truth boxes + compared to the total sum of the activations.""" + binary_mask = np.zeros_like(saliency_map) + + # For each gt box, draw a binary mask + # The binary_mask will be 1 where the gt boxes are located + for bbox in gt_boxes: + unpacked_bbox = bbox[0] + bbox_dict = ast.literal_eval(unpacked_bbox.replace("'", '"')) + xmin, ymin = int(bbox_dict["xmin"]), int(bbox_dict["ymin"]) + width, height = int(bbox_dict["width"]), int(bbox_dict["height"]) + + binary_mask[ymin : ymin + height, xmin : xmin + width] = 1 + + multiplied_mask = binary_mask * saliency_map + numerator = np.sum(multiplied_mask) + denominator = np.sum(saliency_map) + + if denominator == 0: + proportional_energy = 0 + else: + proportional_energy = numerator / denominator + + return proportional_energy + + +def calculate_localization_metrics( + saliency_map, detected_box, ground_truth_box +): + """Calculates localization metrics for a single input image for a given + visualization method.""" + + iou, ioda = _compute_max_iou_and_ioda(detected_box, ground_truth_box) + + proportional_energy = _compute_proportional_energy( + ground_truth_box, saliency_map + ) + + avg_saliency_focus = _compute_avg_saliency_focus( + ground_truth_box, saliency_map + ) + + return iou, ioda, proportional_energy, avg_saliency_focus + + +# Helper function to calculate the metrics for a single target class +# of a single input image. +def process_target_class( + names, + gt_bboxes, + saliency_map_path, + csv_writer, +): + saliency_map = np.load(saliency_map_path) + + # Calculate bounding boxes for largest connected component + # The pixel values above 20% of max value are kept in the mask to + # calculate IoU and IoDA. + # This imitates the process done by the original CAM paper: + # "Learning Deep Features for Discriminative Localization" by Zhou et al. (2015), + # https://arxiv.org/abs/1512.04150 + thresholded_mask = np.copy(saliency_map) + max_value = np.max(thresholded_mask) + threshold_value = 0.2 * max_value + thresholded_mask[thresholded_mask < threshold_value] = 0 + thresholded_mask = (thresholded_mask > 0).astype(np.uint8) + if np.any(thresholded_mask > 0): + _, label_ids, values, _ = cv2.connectedComponentsWithStats( + thresholded_mask, connectivity=8 + ) + largest_label_id = np.argmax(values[1:, cv2.CC_STAT_AREA]) + 1 + largest_label_mask = (label_ids == largest_label_id).astype(np.uint8) + x_L, y_L, w_L, h_L = cv2.boundingRect(largest_label_mask) + x_L, y_L, w_L, h_L = int(x_L), int(y_L), int(w_L), int(h_L) + else: + x_L, y_L, w_L, h_L = 0, 0, 0, 0 + + # Calculate localization metrics + iou, ioda, proportional_energy, asf = calculate_localization_metrics( + saliency_map=saliency_map, + detected_box=(x_L, y_L, w_L, h_L), + ground_truth_box=gt_bboxes, + ) + + # Write metrics to csv file + csv_writer.writerow( + [ + names[0].split("/")[1], + iou, + ioda, + proportional_energy, + asf, + x_L, + y_L, + w_L, + h_L, + ] + ) + + +def run( + input_folder, + data_loader, + dataset_split_name, + csv_writers, +): + """Applies visualization techniques on input CXR, outputs images with + overlaid heatmaps and csv files with measurements. + + Parameters + --------- + + input_folder : str + Directory in which the saliency maps are stored for a specific visualization type. + + data_loader : py:class:`torch.torch.utils.data.DataLoader` + The pytorch Dataloader used to iterate over batches. + + dataset_split_name : str + Name of the dataset split (e.g. "train", "validation", "test"). + + csv_writers : dict + Dictionary containing csv writer objects for each target class. + + Returns + ------- + + all_predictions : list + All the predictions associated with filename and ground truth, saved as .csv. + """ + for samples in tqdm(data_loader, desc="batches", leave=False, disable=None): + # Check if the sample has a bounding box entry + if len(samples) < 4: + logger.warning( + "The dataset does not contain bounding box information. No localization metrics can be calculated." + ) + return + else: + # TB negative labels are skipped (they don't have gt bboxes) + if samples[3][0] == "none": + continue + + names = samples[0] + + gt_bboxes = samples[3] + + for target_class_name, csv_writer in csv_writers.items(): + saliency_map_path = os.path.join( + input_folder, + target_class_name, + dataset_split_name, + names[0].rsplit(".", 1)[0] + ".npy", + ) + + process_target_class( + names, + gt_bboxes, + saliency_map_path, + csv_writer, + ) diff --git a/src/ptbench/engine/saliencymap_generator.py b/src/ptbench/engine/saliencymap_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..598a66da170655d8b47f9c3a9fb089576b4a298a --- /dev/null +++ b/src/ptbench/engine/saliencymap_generator.py @@ -0,0 +1,151 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import logging +import os + +import numpy as np +import torch + +from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget +from tqdm import tqdm + +logger = logging.getLogger(__name__) + +rs_maps = { + 0: "cardiomegaly", + 1: "emphysema", + 2: "effusion", + 3: "hernia", + 4: "infiltration", + 5: "mass", + 6: "nodule", + 7: "atelectasis", + 8: "pneumothorax", + 9: "pleural thickening", + 10: "pneumonia", + 11: "fibrosis", + 12: "edema", + 13: "consolidation", +} + + +def _save_npy(img_name_stem, grayscale_cam, visualization_path): + image_path = f"{visualization_path}/{img_name_stem}" + os.makedirs(os.path.dirname(image_path), exist_ok=True) + + np.save(image_path, grayscale_cam) + + +# Helper function to calculate saliency maps for a single target class +# of a single input image. +def process_target_class(names, images, targets, visualization_path, cam): + grayscale_cams = cam(input_tensor=images, targets=targets) + + for i, grayscale_cam in enumerate(grayscale_cams): + img_name_stem = names[i].split(".")[0] + + _save_npy(img_name_stem, grayscale_cam, visualization_path) + + +def run( + model, + data_loader, + output_folder, + dataset_split_name, + device, + cam, + visualization_type, + target_class="highest", + tb_positive_only=True, +): + """Applies visualization techniques on input CXR, outputs pickled saliency + maps. + + Parameters + --------- + model : py:class:`torch.nn.Module` + The model to use for the saliency map calculation. + + data_loader : py:class:`torch.torch.utils.data.DataLoader` + The pytorch Dataloader used to iterate over batches. + + output_folder : str + Directory in which the results will be saved. + + dataset_split_name : str + Name of the dataset split (e.g. "train", "validation", "test"). + + device : str + A string indicating the device to use (e.g. "cpu" or "cuda"). The device can also be specified (cuda:0) + + cam : py:class: `pytorch_grad_cam.GradCAM`, `pytorch_grad_cam.ScoreCAM`, + `pytorch_grad_cam.FullGrad`, `pytorch_grad_cam.RandomCAM`, + `pytorch_grad_cam.EigenCAM`, `pytorch_grad_cam.EigenGradCAM`, + `pytorch_grad_cam.LayerCAM`, `pytorch_grad_cam.XGradCAM`, + `pytorch_grad_cam.AblationCAM`, `pytorch_grad_cam.HiResCAM`, + `pytorch_grad_cam.GradCAMElementWise`, `pytorch_grad_cam.GradCAMplusplus`, + The CAM object to use for visualization. + + visualization_types : list + Type of visualization techniques to be applied. Possible values are: + "GradCAM", "ScoreCAM", "FullGrad", "RandomCAM", "HiResCAM", "GradCAMElementWise", "GradCAMPlusPlus", "XGradCAM", "AblationCAM", + "EigenCAM", "EigenGradCAM", "LayerCAM". + + target_class : str + (Use only with multi-label models) Which class to target for CAM calculation. Can be either set to "all" or "highest". "highest" is default, which means only saliency maps for the class with the highest activation will be generated. + + tb_positive_only : bool + If set, saliency maps will only be generated for TB positive samples. + + Returns + ------- + + all_predictions : list + All the predictions associated with filename and ground truth, saved as .csv. + """ + output_folder = os.path.abspath(output_folder) + + logger.info(f"Output folder: {output_folder}") + os.makedirs(output_folder, exist_ok=True) + + model_name = model.__class__.__name__ + + for samples in tqdm(data_loader, desc="batches", leave=False, disable=None): + # TB negative labels are skipped (they don't have bboxes) + if samples[2][0].item() == 0: + if tb_positive_only: + continue + + names = samples[0] + images = samples[1].to( + device=device, non_blocking=torch.cuda.is_available() + ) + + if model_name == "DensenetRS" and target_class.lower() == "all": + for target in range(14): + targets = [ClassifierOutputTarget(target)] + + visualization_path = f"{output_folder}/{visualization_type}/{rs_maps[target]}/{dataset_split_name}" + os.makedirs(visualization_path, exist_ok=True) + + process_target_class( + names, images, targets, visualization_path, cam + ) + + if model_name == "DensenetRS": + # Get the class with highest activation manually + outputs = cam.activations_and_grads(images) + target_categories = np.argmax(outputs.cpu().data.numpy(), axis=-1) + targets = [ + ClassifierOutputTarget(category) + for category in target_categories + ] + else: + targets = [ClassifierOutputTarget(0)] + + visualization_path = f"{output_folder}/{visualization_type}/targeted_class/{dataset_split_name}" + os.makedirs(visualization_path, exist_ok=True) + + process_target_class(names, images, targets, visualization_path, cam) diff --git a/src/ptbench/engine/visualizer.py b/src/ptbench/engine/visualizer.py new file mode 100644 index 0000000000000000000000000000000000000000..b734a54732cab488852176469055e0ee4617c18c --- /dev/null +++ b/src/ptbench/engine/visualizer.py @@ -0,0 +1,229 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import logging +import os + +import cv2 +import numpy as np +import pandas as pd + +from PIL import Image +from tqdm import tqdm + +from ..utils.cam_utils import ( + draw_boxes_on_image, + draw_largest_component_bbox_on_image, + show_cam_on_image, + visualize_road_scores, +) + +logger = logging.getLogger(__name__) + + +def _get_target_classes_from_directory(input_folder): + # Gets a list of target classes from a directory + return [ + item + for item in os.listdir(input_folder) + if os.path.isdir(os.path.join(input_folder, item)) + ] + + +def run( + data_loader, + img_dir, + input_folder, + output_folder, + dataset_split_name, + visualize_groundtruth=False, + road_path=None, + visualize_detected_bbox=True, + threshold=0.0, +): + """Overlays saliency maps on CXR to output final images with heatmaps. + + Parameters + --------- + img_dir : str + Directory containing the images from the dataset to use for visualization. + + input_folder : str + Directory in which the saliency maps are stored for a specific visualization type. + + output_folder : str + Directory in which the results will be saved. + + visualize_groundtruth : bool + If set, generate visualizations for ground truth labels. + + road_path : str + If the path to the previously calculated road scores is provided, MoRF, LeRF, and combined ROAD scores will be visualized on each image. + + visualize_detected_bbox : bool + If set, the bounding box for the largest connected component will be visualized on each image. + + threshold : float + Only CAMs above this threshold will be visualized. + + Returns + ------- + + output_folder : str + Directory containing the images overlaid with heatmaps. + """ + assert ( + input_folder != output_folder + ), "Output folder must not be the same as the input folder." + + # Check that the output folder is not a subdirectory of the input_folder + assert not output_folder.startswith( + input_folder + ), "Output folder must not be a subdirectory of the input folder." + + output_folder = os.path.abspath(output_folder) + + logger.info(f"Output folder: {output_folder}") + os.makedirs(output_folder, exist_ok=True) + + target_classes = _get_target_classes_from_directory(input_folder) + + if visualize_detected_bbox: + # Load detected bounding box information + localization_metric_dfs = {} + + for target_class in target_classes: + input_base_path = f"{input_folder}/{target_class}/{dataset_split_name}_localization_metrics.csv" + + if os.path.exists(input_base_path): + df = pd.read_csv(input_base_path) + localization_metric_dfs[target_class] = df + + if road_path is not None: + # Load road score information + road_metric_dfs = {} + + for target_class in target_classes: + road_base_path = f"{road_path}/{target_class}/{dataset_split_name}_road_metrics.csv" + + if os.path.exists(road_base_path): + df = pd.read_csv(road_base_path) + road_metric_dfs[target_class] = df + + for samples in tqdm(data_loader, desc="batches", leave=False, disable=None): + includes_bboxes = True + if len(samples) < 4 or samples[2][0].item() == 0: + includes_bboxes = False + + # if visualize_groundtruth and not includes_bboxes: + # logger.warning( + # "This sample does not have bounding box information. No ground truth bounding boxes can be visualized." + # ) + + names = samples[0] + gt_bboxes = samples[3] + + for target_class in target_classes: + input_base_path = ( + f"{input_folder}/{target_class}/{dataset_split_name}" + ) + vis_directory_name = os.path.basename(input_folder) + output_base_path = f"{output_folder}/{vis_directory_name}/{target_class}/{dataset_split_name}" + + img_path = os.path.join(img_dir, names[0]) + saliency_map_name = names[0].rsplit(".", 1)[0] + ".npy" + saliency_map_path = os.path.join(input_base_path, saliency_map_name) + + # Skip if the saliency map does not exist (e.g. if no saliency maps were generated for TB negative samples) + if not includes_bboxes and not os.path.exists(saliency_map_path): + continue + + saliency_map = np.load(saliency_map_path) + + # Image is expected to be of type float32 + original_img = np.array(Image.open(img_path), dtype=np.float32) + img_with_boxes = original_img.copy() + + # Draw bounding boxes on the image + if includes_bboxes: + if gt_bboxes != "none" and visualize_groundtruth: + img_with_boxes = draw_boxes_on_image( + img_with_boxes, gt_bboxes + ) + + # show_cam_on_image expects the img to be between [0, 1] + rgb_img = img_with_boxes / 255.0 + + # Threshold heatmap + thresholded_mask = np.copy(saliency_map) + thresholded_mask[thresholded_mask < threshold] = 0 + + # Visualize heatmap + visualization = show_cam_on_image( + rgb_img, thresholded_mask, use_rgb=False + ) + + # Visualize bounding box for largest connected component + if visualize_detected_bbox: + if target_class in localization_metric_dfs: + temp_df = localization_metric_dfs[target_class] + temp_filename = names[0].split("/")[1] + + row = temp_df[temp_df["Image"] == temp_filename] + + if not row.empty: + x_L = int(row["detected_bbox_xmin"].values[0]) + y_L = int(row["detected_bbox_ymin"].values[0]) + w_L = int(row["detected_bbox_width"].values[0]) + h_L = int(row["detected_bbox_height"].values[0]) + + visualization = draw_largest_component_bbox_on_image( + visualization, x_L, y_L, w_L, h_L + ) + else: + logger.warning( + f"Could not find entry for {temp_filename} in .csv file." + ) + else: + logger.warning( + f"No localization metrics .csv file found for the {dataset_split_name} split for {target_class}." + ) + + # Visualize ROAD scores + if road_path is not None: + if target_class in road_metric_dfs: + temp_df = road_metric_dfs[target_class] + temp_filename = names[0].split("/")[1] + + row = temp_df[temp_df["Image"] == temp_filename] + + if not row.empty: + MoRF_score = float(row["MoRF"].values[0]) + LeRF_score = float(row["LeRF"].values[0]) + combined_score = float( + row["Combined Score ((LeRF-MoRF) / 2)"].values[0] + ) + percentiles = str(row["Percentiles"].values[0]) + else: + logger.warning( + f"Could not find entry for {temp_filename} in .csv file." + ) + else: + logger.warning( + f"No ROAD metrics .csv file found for the {dataset_split_name} split for {target_class}." + ) + + visualization = visualize_road_scores( + visualization, + MoRF_score, + LeRF_score, + combined_score, + vis_directory_name, + percentiles, + ) + + # Save image + output_file_path = os.path.join(output_base_path, names[0]) + os.makedirs(os.path.dirname(output_file_path), exist_ok=True) + cv2.imwrite(output_file_path, visualization) diff --git a/src/ptbench/scripts/calculate_road.py b/src/ptbench/scripts/calculate_road.py new file mode 100644 index 0000000000000000000000000000000000000000..5c9080956f956340cc2c22471c16fc7716cf58db --- /dev/null +++ b/src/ptbench/scripts/calculate_road.py @@ -0,0 +1,388 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import csv +import os + +import click + +from clapper.click import ConfigCommand, ResourceOption, verbosity_option +from clapper.logging import setup +from pytorch_grad_cam import ( + AblationCAM, + EigenCAM, + EigenGradCAM, + FullGrad, + GradCAM, + GradCAMElementWise, + GradCAMPlusPlus, + HiResCAM, + LayerCAM, + RandomCAM, + ScoreCAM, + XGradCAM, +) + +logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") + +allowed_visualization_types = { + "gradcam", + "scorecam", + "fullgrad", + "randomcam", + "hirescam", + "gradcamelementwise", + "gradcam++", + "gradcamplusplus", + "xgradcam", + "ablationcam", + "eigencam", + "eigengradcam", + "layercam", +} + +rs_maps = { + 0: "cardiomegaly", + 1: "emphysema", + 2: "effusion", + 3: "hernia", + 4: "infiltration", + 5: "mass", + 6: "nodule", + 7: "atelectasis", + 8: "pneumothorax", + 9: "pleural thickening", + 10: "pneumonia", + 11: "fibrosis", + 12: "edema", + 13: "consolidation", +} + + +# To ensure that the user has selected a supported visualization type +def check_vis_types(vis_types): + if isinstance(vis_types, str): + vis_types = [vis_types.lower()] + else: + vis_types = [s.lower() for s in vis_types] + + for s in vis_types: + if not isinstance(s, str): + raise click.BadParameter( + "Visualization type must be a string or a list of strings" + ) + if s not in allowed_visualization_types: + raise click.BadParameter( + "Visualization type must be one of: {}".format( + ", ".join(allowed_visualization_types) + ) + ) + return vis_types + + +# CAM factory +def create_cam(vis_type, model, target_layers, use_cuda): + if vis_type == "gradcam": + return GradCAM( + model=model, target_layers=target_layers, use_cuda=use_cuda + ) + elif vis_type == "scorecam": + return ScoreCAM( + model=model, target_layers=target_layers, use_cuda=use_cuda + ) + elif vis_type == "fullgrad": + return FullGrad( + model=model, target_layers=target_layers, use_cuda=use_cuda + ) + elif vis_type == "randomcam": + return RandomCAM( + model=model, target_layers=target_layers, use_cuda=use_cuda + ) + elif vis_type == "hirescam": + return HiResCAM( + model=model, target_layers=target_layers, use_cuda=use_cuda + ) + elif vis_type == "gradcamelementwise": + return GradCAMElementWise( + model=model, target_layers=target_layers, use_cuda=use_cuda + ) + elif vis_type == "gradcam++" or vis_type == "gradcamplusplus": + return GradCAMPlusPlus( + model=model, target_layers=target_layers, use_cuda=use_cuda + ) + elif vis_type == "xgradcam": + return XGradCAM( + model=model, target_layers=target_layers, use_cuda=use_cuda + ) + elif vis_type == "ablationcam": + return AblationCAM( + model=model, target_layers=target_layers, use_cuda=use_cuda + ) + elif vis_type == "eigencam": + return EigenCAM( + model=model, target_layers=target_layers, use_cuda=use_cuda + ) + elif vis_type == "eigengradcam": + return EigenGradCAM( + model=model, target_layers=target_layers, use_cuda=use_cuda + ) + elif vis_type == "layercam": + return LayerCAM( + model=model, target_layers=target_layers, use_cuda=use_cuda + ) + else: + raise ValueError(f"Unsupported visualization type: {vis_type}") + + +def prepare_csv_writers( + output_folder, visualization_type, dataset_split, num_classes=14 +): + # Create a CSV file to store the performance metrics for each image, for each target class + + csv_files = {} + csv_writers = {} + + for target in range(num_classes): + directory_path = ( + f"{output_folder}/{visualization_type}/{rs_maps[target]}" + ) + os.makedirs(directory_path, exist_ok=True) + csv_files[rs_maps[target]] = open( + f"{directory_path}/{dataset_split}_road_metrics.csv", + "w", + newline="", + ) + csv_writers[rs_maps[target]] = csv.writer(csv_files[rs_maps[target]]) + csv_writers[rs_maps[target]].writerow( + [ + "Image", + "MoRF", + "LeRF", + "Combined Score ((LeRF-MoRF) / 2)", + "Percentiles", + ] + ) + + directory_path = f"{output_folder}/{visualization_type}/targeted_class" + os.makedirs(directory_path, exist_ok=True) + csv_files["targeted_class"] = open( + f"{directory_path}/{dataset_split}_road_metrics.csv", "w", newline="" + ) + csv_writers["targeted_class"] = csv.writer(csv_files["targeted_class"]) + csv_writers["targeted_class"].writerow( + [ + "Image", + "MoRF", + "LeRF", + "Combined Score ((LeRF-MoRF) / 2)", + "Percentiles", + ] + ) + + return csv_files, csv_writers + + +@click.command( + entry_point_group="ptbench.config", + cls=ConfigCommand, + epilog="""Examples: + +\b + 1. Calculates the ROAD scores for an existing dataset configuration and stores them in .csv files: + + .. code:: sh + + ptbench calculate-road -vv pasa tbx11k_simplified_bbox -a "cuda" --weight=path/to/model_final.pth --output-folder=path/to/output_folder + +""", +) +@click.option( + "--model", + "-m", + help="A torch.nn.Module instance implementing the network to be evaluated", + required=True, + cls=ResourceOption, +) +@click.option( + "--dataset", + "-d", + help="A torch.utils.data.dataset.Dataset instance implementing a dataset " + "to be used for generating visualizations, possibly including all pre-processing " + "pipelines required or, optionally, a dictionary mapping string keys to " + "torch.utils.data.dataset.Dataset instances. All keys that do not start " + "with an underscore (_) will be processed.", + required=True, + cls=ResourceOption, +) +@click.option( + "--output-folder", + "-o", + help="Path where to store the road metrics .csv files (created if does not exist)", + required=True, + default="visualizations", + cls=ResourceOption, + type=click.Path(), +) +@click.option( + "--batch-size", + "-b", + help="Number of samples in every batch (this parameter affects memory requirements for the network)", + required=True, + show_default=True, + default=1, + type=click.IntRange(min=1), + cls=ResourceOption, +) +@click.option( + "--accelerator", + "-a", + help='A string indicating the accelerator to use (e.g. "cpu" or "gpu"). The device can also be specified (gpu:0)', + show_default=True, + required=True, + default="cpu", + cls=ResourceOption, +) +@click.option( + "--weight", + "-w", + help="Path or URL to pretrained model file (.ckpt extension)", + required=True, + cls=ResourceOption, +) +@click.option( + "--visualization-types", + "-vt", + help="Visualization techniques to be used. Can be called multiple times with different techniques. Currently supported ones are: " + '"GradCAM", "ScoreCAM", "FullGrad", "RandomCAM", "HiResCAM", "GradCAMElementWise", "GradCAMPlusPlus", "XGradCAM", "AblationCAM", ' + '"EigenCAM", "EigenGradCAM", "LayerCAM"', + multiple=True, + default=["GradCAM"], + cls=ResourceOption, +) +@click.option( + "--target-class", + "-tc", + help='(Use only with multi-label models) Which class to target for ROAD calculation. Can be either set to "all" or "highest". "highest" is default, which means only scores for the class with the highest activation will be calculated.', + type=str, + required=False, + default="highest", + cls=ResourceOption, +) +@click.option( + "--tb-positive-only", + "-tb", + help="If set, ROAD scores will only be calculate for TB positive samples.", + is_flag=True, + default=False, + cls=ResourceOption, +) +@verbosity_option(logger=logger, cls=ResourceOption, expose_value=False) +def calculate_road( + model, + dataset, + output_folder, + batch_size, + accelerator, + weight, + visualization_types, + target_class, + tb_positive_only, + **_, +) -> None: + """Creates .csv files with the AOPC_ROAD scores calculated with LeRF, MoRF + and their combination. + + Calculates them for each target class and split of the dataset. + """ + + import torch + + from torch.utils.data import DataLoader + + from ..engine.road_calculator import run + + # Temporary solution due to transition to PyTorch Lightning + if accelerator.startswith("cuda") or accelerator.startswith("gpu"): + use_cuda = torch.cuda.is_available() + device = "cuda:0" if use_cuda else "cpu" + else: + use_cuda = False + device = "cpu" + + if "datadir" in dataset: + dataset = ( + dataset["dataset"] + if isinstance(dataset["dataset"], dict) + else dict(test=dataset["dataset"]) + ) + else: + dataset = dataset if isinstance(dataset, dict) else dict(test=dataset) + + logger.info(f"Loading checkpoint from {weight}") + + # This is a temporary solution due to transition to PyTorch Lightning + # This will not be necessary for future users of this package + state_dict = torch.load(weight, map_location=torch.device("cpu")).pop( + "model" + ) + new_state_dict = {k.replace("model.", ""): v for k, v in state_dict.items()} + model.load_state_dict(new_state_dict) + + # This code should work for future users of this package (no guarantee) + # model = model.load_from_checkpoint(weight, strict=False) + + model.eval() + + visualization_types = check_vis_types(visualization_types) + + model_name = model.__class__.__name__ + + if model_name == "PASA": + if "fullgrad" in visualization_types: + raise ValueError( + "Fullgrad visualization is not supported for the Pasa model." + ) + target_layers = [model.fc14] # Last non-1x1 Conv2d layer + else: + target_layers = [model.model_ft.features.denseblock4.denselayer16.conv2] + + for vis_type in visualization_types: + cam = create_cam(vis_type, model, target_layers, use_cuda) + + for k, v in dataset.items(): + if k.startswith("_"): + logger.info(f"Skipping dataset '{k}' (not to be evaluated)") + continue + + if model_name == "DensenetRS" and target_class.lower() == "all": + csv_files, csv_writers = prepare_csv_writers( + output_folder, vis_type, k, num_classes=14 + ) + else: + csv_files, csv_writers = prepare_csv_writers( + output_folder, vis_type, k, num_classes=0 + ) + + logger.info(f"Calculating ROAD scores for '{k}' set...") + + data_loader = DataLoader( + dataset=v, + batch_size=batch_size, + shuffle=False, + pin_memory=torch.cuda.is_available(), + ) + + run( + model, + data_loader, + output_folder=output_folder, + device=device, + cam=cam, + csv_writers=csv_writers, + target_class=target_class, + tb_positive_only=tb_positive_only, + ) + + for csv_file in csv_files.values(): + csv_file.close() diff --git a/src/ptbench/scripts/cli.py b/src/ptbench/scripts/cli.py index 11cba171fb6857ca3e6fe08e2c355ab12e429f0a..c9cacaa95c3108dae172e85247f698bb49bbec5f 100644 --- a/src/ptbench/scripts/cli.py +++ b/src/ptbench/scripts/cli.py @@ -7,13 +7,19 @@ import click from clapper.click import AliasedGroup from . import ( + calculate_road, + comparevis, config, database, evaluate, + evaluate_saliencymaps, + evaluatevis, experiment, + generate_saliencymaps, predict, train, train_analysis, + visualize, ) @@ -26,10 +32,16 @@ def cli(): pass +cli.add_command(calculate_road.calculate_road) +cli.add_command(comparevis.comparevis) cli.add_command(config.config) cli.add_command(database.database) cli.add_command(evaluate.evaluate) +cli.add_command(evaluate_saliencymaps.evaluate_saliencymaps) +cli.add_command(evaluatevis.evaluatevis) cli.add_command(experiment.experiment) +cli.add_command(generate_saliencymaps.generate_saliencymaps) cli.add_command(predict.predict) cli.add_command(train.train) cli.add_command(train_analysis.train_analysis) +cli.add_command(visualize.visualize) diff --git a/src/ptbench/scripts/comparevis.py b/src/ptbench/scripts/comparevis.py new file mode 100644 index 0000000000000000000000000000000000000000..08eebc129ad08b4020df4d21f7025b4d63dfc904 --- /dev/null +++ b/src/ptbench/scripts/comparevis.py @@ -0,0 +1,184 @@ +import os + +import click + +from clapper.click import verbosity_option +from clapper.logging import setup + +logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") + + +def _sorting_rule(folder_name): + folder_name = os.path.basename(folder_name) + + if "gradcam" == folder_name: + return (0, folder_name) + elif "scorecam" == folder_name: + return (1, folder_name) + elif "fullgrad" == folder_name: + return (2, folder_name) + elif "randomcam" == folder_name: + return (4, folder_name) + else: # Everything else will be sorted alphabetically after fullgrad and before randomcam + return (3, folder_name) + + +@click.command( + epilog="""Examples: + +\b + 1. Compares different visualization techniques by showing their results side-by-side + (on a per image basis). The input_folder should be a parent folder containing different + folders for each visualization technique (e.g. input_folder/gradcam, input_folder/scorecam, etc.): + + .. code:: sh + + ptbench comparevis -i path/to/input_folder -o path/to/output_folder +""", +) +@click.option( + "--input-folder", + "-i", + help="Path to the parent folder of resulted visualizations from visualize command", + required=True, + type=click.Path(), +) +@click.option( + "--output-folder", + "-o", + help="Path where to store the cam comparisons (created if does not exist)", + required=True, + default="cam_comparisons", + type=click.Path(), +) +@verbosity_option(logger=logger, expose_value=False) +def comparevis(input_folder, output_folder) -> None: + """Compares multiple visualization techniques by showing their results in + one image.""" + + import os + + import cv2 + import matplotlib.pyplot as plt + + assert ( + input_folder != output_folder + ), "Output folder must not be the same as the input folder." + + # Check that the output folder is not a subdirectory of the input_folder + assert not output_folder.startswith( + input_folder + ), "Output folder must not be a subdirectory of the input folder." + + if not os.path.exists(output_folder): + os.makedirs(output_folder) + + dir_element_names = os.listdir(input_folder) + + # Sort folders by visualization type + dir_element_names.sort(key=_sorting_rule) + + # Get the list of subfolders + vis_type_folders = [ + os.path.join(input_folder, d) + for d in dir_element_names + if os.path.isdir(os.path.join(input_folder, d)) + ] + + if not vis_type_folders: + raise ValueError("No subdirectories found in the parent folder.") + + for vis_type_folder in vis_type_folders: + target_class_folders = next(os.walk(vis_type_folder))[1] + for target_class_folder in target_class_folders: + current_target_class_dir = os.path.join( + vis_type_folder, target_class_folder + ) + dataset_type_folders = next(os.walk(current_target_class_dir))[1] + for dataset_type_folder in dataset_type_folders: + comparison_folders = [ + os.path.join( + input_folder, + vis_type, + target_class_folder, + dataset_type_folder, + ) + for vis_type in dir_element_names + ] + + valid_folders = [] + + for folder in comparison_folders: + if not os.path.exists(folder): + logger.warning(f"Folder does not exist: {folder}") + else: + valid_folders.append(folder) + + comparison_folders = valid_folders + + output_directory = os.path.join( + output_folder, target_class_folder, dataset_type_folder + ) + os.makedirs(output_directory, exist_ok=True) + + # Use a set (unordered collection of unique elements) for efficient membership tests + image_names = set(os.listdir(comparison_folders[0])) + + # Only keep image names that exist in all folders + for folder in comparison_folders[1:]: + # This is basically an intersection-check of contents of different folders + # Images that don't exist in all folders are removed from the set + image_names &= set(os.listdir(folder)) + + if not image_names: + raise ValueError("No common images found in the folders.") + + max_cols = min(5, len(comparison_folders)) + rows, cols = ( + len(comparison_folders) // max_cols, + len(comparison_folders) % max_cols, + ) + rows = rows if cols == 0 else rows + 1 + + for image_name in image_names: + fig, axs = plt.subplots( + rows, max_cols, figsize=(4 * max_cols, 4 * rows) + ) + fig.subplots_adjust( + left=0, + right=1, + bottom=0, + top=1, + wspace=0.05, + hspace=0.1, + ) + axs = axs.ravel() + for i, folder in enumerate(comparison_folders): + image_path = os.path.join(folder, image_name) + try: + img = cv2.imread(image_path) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + except cv2.error as e: + raise RuntimeError( + "Error reading or processing image: " + + image_path + ) from e + axs[i].imshow(img) + axs[i].set_title( + os.path.basename( + os.path.dirname(os.path.dirname(folder)) + ), + fontsize=20, + ) + axs[i].axis("off") + + # Fill the remaining columns with empty plots (white) + for i in range(len(comparison_folders), rows * max_cols): + axs[i].axis("off") + + plt.savefig( + os.path.join(output_directory, f"{image_name}"), + bbox_inches="tight", + pad_inches=0, + ) + plt.close(fig) diff --git a/src/ptbench/scripts/evaluate_saliencymaps.py b/src/ptbench/scripts/evaluate_saliencymaps.py new file mode 100644 index 0000000000000000000000000000000000000000..991e9e5e55ec7bb3a1ee92e2a22b91a8c8a3f938 --- /dev/null +++ b/src/ptbench/scripts/evaluate_saliencymaps.py @@ -0,0 +1,144 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import csv +import os + +import click + +from clapper.click import ConfigCommand, ResourceOption, verbosity_option +from clapper.logging import setup + +logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") + + +def _get_target_classes_from_directory(input_folder): + # Gets a list of target classes from a directory + return [ + item + for item in os.listdir(input_folder) + if os.path.isdir(os.path.join(input_folder, item)) + ] + + +def prepare_csv_writers(input_folder, dataset_split): + # Create a CSV file to store the performance metrics for each image, for each target class + + target_classes = _get_target_classes_from_directory(input_folder) + + csv_files = {} + csv_writers = {} + + for target_class in target_classes: + directory_path = os.path.join(input_folder, target_class) + os.makedirs(directory_path, exist_ok=True) + csv_file_path = os.path.join( + directory_path, f"{dataset_split}_localization_metrics.csv" + ) + csv_files[target_class] = open(csv_file_path, "w", newline="") + csv_writers[target_class] = csv.writer(csv_files[target_class]) + csv_writers[target_class].writerow( + [ + "Image", + "IoU", + "IoDA", + "Proportional Energy", + "Average Saliency Focus", + "detected_bbox_xmin", + "detected_bbox_ymin", + "detected_bbox_width", + "detected_bbox_height", + ] + ) + + return csv_files, csv_writers + + +@click.command( + entry_point_group="ptbench.config", + cls=ConfigCommand, + epilog="""Examples: + +\b + 1. Evaluate the generated saliency maps for their localization performance: + + .. code:: sh + + ptbench evaluate-saliencymaps -vv tbx11k_simplified_bbox_rgb --input-folder=parent_folder/gradcam/ + +""", +) +@click.option( + "--dataset", + "-d", + help="A torch.utils.data.dataset.Dataset instance implementing a dataset " + "to be used for generating visualizations, possibly including all pre-processing " + "pipelines required or, optionally, a dictionary mapping string keys to " + "torch.utils.data.dataset.Dataset instances. All keys that do not start " + "with an underscore (_) will be processed.", + required=True, + cls=ResourceOption, +) +@click.option( + "--input-folder", + "-i", + help="Path to the folder containing the saliency maps for a specific visualization type.", + required=True, + default="visualizations", + cls=ResourceOption, + type=click.Path(), +) +@verbosity_option(logger=logger, cls=ResourceOption, expose_value=False) +def evaluate_saliencymaps( + dataset, + input_folder, + **_, +) -> None: + """Creates .csv files with the IoU, IoDA, Proportional Energy, and ASF + metrics, and additionally saves the detected bounding box coordinates for + each image. + + Calculates them for each target class and split of the dataset. + """ + + import torch + + from torch.utils.data import DataLoader + + from ..engine.saliencymap_evaluator import run + + if "datadir" in dataset: + dataset = ( + dataset["dataset"] + if isinstance(dataset["dataset"], dict) + else dict(test=dataset["dataset"]) + ) + else: + dataset = dataset if isinstance(dataset, dict) else dict(test=dataset) + + for k, v in dataset.items(): + if k.startswith("_"): + logger.info(f"Skipping dataset '{k}' (not to be evaluated)") + continue + + csv_files, csv_writers = prepare_csv_writers(input_folder, k) + + logger.info(f"Calculating localization metrics for '{k}' set...") + + data_loader = DataLoader( + dataset=v, + batch_size=1, + shuffle=False, + pin_memory=torch.cuda.is_available(), + ) + + run( + input_folder, + data_loader, + dataset_split_name=k, + csv_writers=csv_writers, + ) + + for csv_file in csv_files.values(): + csv_file.close() diff --git a/src/ptbench/scripts/evaluatevis.py b/src/ptbench/scripts/evaluatevis.py new file mode 100644 index 0000000000000000000000000000000000000000..e98dcad992960f8cb59b199b661591647412fbc1 --- /dev/null +++ b/src/ptbench/scripts/evaluatevis.py @@ -0,0 +1,35 @@ +import click + +from clapper.click import verbosity_option +from clapper.logging import setup + +logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") + + +@click.command( + epilog="""Examples: + +\b + 1. Takes all the .csv files resulting from the visualize command and generates + a summary.csv file containing a summary of the results. + + .. code:: sh + + ptbench evaluatevis -i path/to/input_folder +""", +) +@click.option( + "--input-folder", + "-i", + help="Path to the folder including the .csv files with the individual scores.", + required=True, + type=click.Path(), +) +@verbosity_option(logger=logger, expose_value=False) +def evaluatevis(input_folder) -> None: + """Calculates summary statistics for the scores for one visualization + type.""" + + from ..utils.cam_utils import calculate_metrics_avg_for_every_class + + calculate_metrics_avg_for_every_class(input_folder) diff --git a/src/ptbench/scripts/generate_saliencymaps.py b/src/ptbench/scripts/generate_saliencymaps.py new file mode 100644 index 0000000000000000000000000000000000000000..ffecafe9b4f709dc00880bce80bf3f233c42e749 --- /dev/null +++ b/src/ptbench/scripts/generate_saliencymaps.py @@ -0,0 +1,306 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import click + +from clapper.click import ConfigCommand, ResourceOption, verbosity_option +from clapper.logging import setup +from pytorch_grad_cam import ( + AblationCAM, + EigenCAM, + EigenGradCAM, + FullGrad, + GradCAM, + GradCAMElementWise, + GradCAMPlusPlus, + HiResCAM, + LayerCAM, + RandomCAM, + ScoreCAM, + XGradCAM, +) + +logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") + +allowed_visualization_types = { + "gradcam", + "scorecam", + "fullgrad", + "randomcam", + "hirescam", + "gradcamelementwise", + "gradcam++", + "gradcamplusplus", + "xgradcam", + "ablationcam", + "eigencam", + "eigengradcam", + "layercam", +} + + +# To ensure that the user has selected a supported visualization type +def check_vis_types(vis_types): + if isinstance(vis_types, str): + vis_types = [vis_types.lower()] + else: + vis_types = [s.lower() for s in vis_types] + + for s in vis_types: + if not isinstance(s, str): + raise click.BadParameter( + "Visualization type must be a string or a list of strings" + ) + if s not in allowed_visualization_types: + raise click.BadParameter( + "Visualization type must be one of: {}".format( + ", ".join(allowed_visualization_types) + ) + ) + return vis_types + + +# CAM factory +def create_cam(vis_type, model, target_layers, use_cuda): + if vis_type == "gradcam": + return GradCAM( + model=model, target_layers=target_layers, use_cuda=use_cuda + ) + elif vis_type == "scorecam": + return ScoreCAM( + model=model, target_layers=target_layers, use_cuda=use_cuda + ) + elif vis_type == "fullgrad": + return FullGrad( + model=model, target_layers=target_layers, use_cuda=use_cuda + ) + elif vis_type == "randomcam": + return RandomCAM( + model=model, target_layers=target_layers, use_cuda=use_cuda + ) + elif vis_type == "hirescam": + return HiResCAM( + model=model, target_layers=target_layers, use_cuda=use_cuda + ) + elif vis_type == "gradcamelementwise": + return GradCAMElementWise( + model=model, target_layers=target_layers, use_cuda=use_cuda + ) + elif vis_type == "gradcam++" or vis_type == "gradcamplusplus": + return GradCAMPlusPlus( + model=model, target_layers=target_layers, use_cuda=use_cuda + ) + elif vis_type == "xgradcam": + return XGradCAM( + model=model, target_layers=target_layers, use_cuda=use_cuda + ) + elif vis_type == "ablationcam": + return AblationCAM( + model=model, target_layers=target_layers, use_cuda=use_cuda + ) + elif vis_type == "eigencam": + return EigenCAM( + model=model, target_layers=target_layers, use_cuda=use_cuda + ) + elif vis_type == "eigengradcam": + return EigenGradCAM( + model=model, target_layers=target_layers, use_cuda=use_cuda + ) + elif vis_type == "layercam": + return LayerCAM( + model=model, target_layers=target_layers, use_cuda=use_cuda + ) + else: + raise ValueError(f"Unsupported visualization type: {vis_type}") + + +@click.command( + entry_point_group="ptbench.config", + cls=ConfigCommand, + epilog="""Examples: + +\b + 1. Generates saliency maps and saves them as pickeled objects: + + .. code:: sh + + ptbench generate-saliencymaps -vv densenet tbx11k_simplified_bbox_rgb --accelerator="cuda" --weight=path/to/model_final.pth --output-folder=path/to/visualizations + +""", +) +@click.option( + "--model", + "-m", + help="A torch.nn.Module instance implementing the network to be evaluated", + required=True, + cls=ResourceOption, +) +@click.option( + "--dataset", + "-d", + help="A torch.utils.data.dataset.Dataset instance implementing a dataset " + "to be used for generating visualizations, possibly including all pre-processing " + "pipelines required or, optionally, a dictionary mapping string keys to " + "torch.utils.data.dataset.Dataset instances. All keys that do not start " + "with an underscore (_) will be processed.", + required=True, + cls=ResourceOption, +) +@click.option( + "--output-folder", + "-o", + help="Path where to store the visualizations (created if does not exist)", + required=True, + default="visualizations", + cls=ResourceOption, + type=click.Path(), +) +@click.option( + "--batch-size", + "-b", + help="Number of samples in every batch (this parameter affects memory requirements for the network)", + required=True, + show_default=True, + default=1, + type=click.IntRange(min=1), + cls=ResourceOption, +) +@click.option( + "--accelerator", + "-a", + help='A string indicating the accelerator to use (e.g. "cpu" or "gpu"). The device can also be specified (gpu:0)', + show_default=True, + required=True, + default="cpu", + cls=ResourceOption, +) +@click.option( + "--weight", + "-w", + help="Path or URL to pretrained model file (.ckpt extension)", + required=True, + cls=ResourceOption, +) +@click.option( + "--visualization-types", + "-vt", + help="Visualization techniques to be used. Can be called multiple times with different techniques. Currently supported ones are: " + '"GradCAM", "ScoreCAM", "FullGrad", "RandomCAM", "HiResCAM", "GradCAMElementWise", "GradCAMPlusPlus", "XGradCAM", "AblationCAM", ' + '"EigenCAM", "EigenGradCAM", "LayerCAM"', + multiple=True, + default=["GradCAM"], + cls=ResourceOption, +) +@click.option( + "--target-class", + "-tc", + help='(Use only with multi-label models) Which class to target for CAM calculation. Can be either set to "all" or "highest". "highest" is default, which means only saliency maps for the class with the highest activation will be generated.', + type=str, + required=False, + default="highest", + cls=ResourceOption, +) +@click.option( + "--tb-positive-only", + "-tb", + help="If set, saliency maps will only be generated for TB positive samples.", + is_flag=True, + default=False, + cls=ResourceOption, +) +@verbosity_option(logger=logger, cls=ResourceOption, expose_value=False) +def generate_saliencymaps( + model, + dataset, + output_folder, + batch_size, + accelerator, + weight, + visualization_types, + target_class, + tb_positive_only, + **_, +) -> None: + """Generates saliency maps for locations with aTB for input CXRs, depending + on visualization technique and model.""" + + import torch + + from torch.utils.data import DataLoader + + from ..engine.saliencymap_generator import run + + # Temporary solution due to transition to PyTorch Lightning + if accelerator.startswith("cuda") or accelerator.startswith("gpu"): + use_cuda = torch.cuda.is_available() + device = "cuda:0" if use_cuda else "cpu" + else: + use_cuda = False + device = "cpu" + + if "datadir" in dataset: + dataset = ( + dataset["dataset"] + if isinstance(dataset["dataset"], dict) + else dict(test=dataset["dataset"]) + ) + else: + dataset = dataset if isinstance(dataset, dict) else dict(test=dataset) + + logger.info(f"Loading checkpoint from {weight}") + + # This is a temporary solution due to transition to PyTorch Lightning + # This will not be necessary for future users of this package + state_dict = torch.load(weight, map_location=torch.device("cpu")).pop( + "model" + ) + new_state_dict = {k.replace("model.", ""): v for k, v in state_dict.items()} + model.load_state_dict(new_state_dict) + + # This code should work for future users of this package (no guarantee) + # model = model.load_from_checkpoint(weight, strict=False) + + model.eval() + + visualization_types = check_vis_types(visualization_types) + + model_name = model.__class__.__name__ + + if model_name == "PASA": + if "fullgrad" in visualization_types: + raise ValueError( + "Fullgrad visualization is not supported for the Pasa model." + ) + target_layers = [model.fc14] # Last non-1x1 Conv2d layer + else: + target_layers = [model.model_ft.features.denseblock4.denselayer16.conv2] + + for vis_type in visualization_types: + cam = create_cam(vis_type, model, target_layers, use_cuda) + + for k, v in dataset.items(): + if k.startswith("_"): + logger.info(f"Skipping dataset '{k}' (not to be evaluated)") + continue + + logger.info(f"Generating saliency maps for '{k}' set...") + + data_loader = DataLoader( + dataset=v, + batch_size=batch_size, + shuffle=False, + pin_memory=torch.cuda.is_available(), + ) + + run( + model, + data_loader, + output_folder=output_folder, + dataset_split_name=k, + device=device, + cam=cam, + visualization_type=vis_type, + target_class=target_class, + tb_positive_only=tb_positive_only, + ) diff --git a/src/ptbench/scripts/visualize.py b/src/ptbench/scripts/visualize.py new file mode 100644 index 0000000000000000000000000000000000000000..b3922581e6d876d0992ee7c17cd02f1f1db4f9db --- /dev/null +++ b/src/ptbench/scripts/visualize.py @@ -0,0 +1,144 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import click + +from clapper.click import ConfigCommand, ResourceOption, verbosity_option +from clapper.logging import setup + +logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") + + +@click.command( + entry_point_group="ptbench.config", + cls=ConfigCommand, + epilog="""Examples: + +\b + 1. Generates visualizations in form of heatmaps from existing saliency maps for a dataset configuration: + + .. code:: sh + + ptbench visualize -vv tbx11k_simplified_bbox_rgb --input-folder=parent_folder/gradcam/ --road-path=parent_folder/gradcam --output-folder=path/to/visualizations + +""", +) +@click.option( + "--dataset", + "-d", + help="A torch.utils.data.dataset.Dataset instance implementing a dataset " + "to be used for generating visualizations, possibly including all pre-processing " + "pipelines required or, optionally, a dictionary mapping string keys to " + "torch.utils.data.dataset.Dataset instances. All keys that do not start " + "with an underscore (_) will be processed.", + required=True, + cls=ResourceOption, +) +@click.option( + "--input-folder", + "-i", + help="Path to the folder containing the saliency maps for a specific visualization type.", + required=True, + default="visualizations", + cls=ResourceOption, + type=click.Path(), +) +@click.option( + "--output-folder", + "-o", + help="Path where to store the visualizations (created if does not exist)", + required=True, + default="visualizations", + cls=ResourceOption, + type=click.Path(), +) +@click.option( + "--visualize-groundtruth", + "-vgt", + help="If set, visualizations for ground truth labels will be generated. Only works for datasets with bounding boxes.", + is_flag=True, + default=False, + cls=ResourceOption, +) +@click.option( + "--road-path", + "-r", + help="If the path to the previously calculated road scores is provided, MoRF, LeRF, and combined ROAD scores will be visualized on each image.", + required=False, + default=None, + cls=ResourceOption, + type=click.Path(), +) +@click.option( + "--visualize-detected-bbox", + "-vb", + help="If set, largest component bounding boxes will be visualized on each image. Only works if the bounding boxes have been previously generated.", + is_flag=True, + default=False, + cls=ResourceOption, +) +@click.option( + "--threshold", + "-t", + help="Only activations above this threshold will be visualized.", + show_default=True, + required=True, + default=0.0, + type=click.FloatRange(min=0, max=1), + cls=ResourceOption, +) +@verbosity_option(logger=logger, cls=ResourceOption, expose_value=False) +def visualize( + dataset, + input_folder, + output_folder, + visualize_groundtruth, + road_path, + visualize_detected_bbox, + threshold, + **_, +) -> None: + """Generates heatmaps for input CXRs based on existing saliency maps.""" + + import torch + + from torch.utils.data import DataLoader + + from ..engine.visualizer import run + + if "datadir" in dataset: + img_dir = dataset["datadir"] + dataset = ( + dataset["dataset"] + if isinstance(dataset["dataset"], dict) + else dict(test=dataset["dataset"]) + ) + + for k, v in dataset.items(): + if k.startswith("_"): + logger.info(f"Skipping dataset '{k}' (not to be evaluated)") + continue + + data_loader = DataLoader( + dataset=v, + batch_size=1, + shuffle=False, + pin_memory=torch.cuda.is_available(), + ) + + run( + data_loader, + img_dir=img_dir, + input_folder=input_folder, + output_folder=output_folder, + dataset_split_name=k, + visualize_groundtruth=visualize_groundtruth, + road_path=road_path, + visualize_detected_bbox=visualize_detected_bbox, + threshold=threshold, + ) + else: + logger.warning( + 'No "datadir" or "dataset_name" key found in dataset. No visualizations can be generated.' + ) diff --git a/src/ptbench/utils/cam_utils.py b/src/ptbench/utils/cam_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4472f4515b5b0233865c2c5ba79c51520b659e0c --- /dev/null +++ b/src/ptbench/utils/cam_utils.py @@ -0,0 +1,291 @@ +import ast +import glob +import os +import shutil + +import cv2 +import numpy as np +import pandas as pd + + +def _calculate_stats_over_dataset(file_path): + data = pd.read_csv(file_path) + + # Compute mean, median and standard deviation for each score type + metrics = {} + for column in [ + "MoRF", + "LeRF", + "Combined Score ((LeRF-MoRF) / 2)", + "IoU", + "IoDA", + "propEnergy", + "ASF", + ]: + mean = round(data[column].mean(), 3) + median = round(data[column].median(), 3) + std_dev = round(data[column].std(), 3) + q1 = round(data[column].quantile(0.25), 3) + q3 = round(data[column].quantile(0.75), 3) + + metrics[column] = { + "mean": mean, + "median": median, + "std_dev": std_dev, + "Q1": q1, + "Q3": q3, + } + + return metrics + + +def calculate_metrics_avg_for_every_class(input_folder): + directories = [ + dir + for dir in os.listdir(input_folder) + if os.path.isdir(os.path.join(input_folder, dir)) + ] + + # Get unique file names ending with .csv without the extension + csv_file_names = set() + for directory in directories: + dir_path = os.path.join(input_folder, directory) + csv_files_in_dir = glob.glob(os.path.join(dir_path, "*.csv")) + for csv_file in csv_files_in_dir: + csv_file_name = os.path.splitext(os.path.basename(csv_file))[0] + csv_file_names.add(csv_file_name) + + all_results = {file_name: [] for file_name in csv_file_names} + + for file_name in csv_file_names: + previous_summary = os.path.join( + input_folder, f"{file_name}_summary.csv" + ) + if os.path.exists(previous_summary): + backup = previous_summary + "~" + if os.path.exists(backup): + os.unlink(backup) + shutil.move(previous_summary, backup) + + for directory in directories: + dir_path = os.path.join(input_folder, directory) + + for file_name in csv_file_names: + csv_file = os.path.join(dir_path, f"{file_name}.csv") + + if os.path.exists(csv_file): + class_name = directory + metrics = _calculate_stats_over_dataset(csv_file) + + for score_type, values in metrics.items(): + result = { + "Class Name": class_name, + "Score Type": score_type, + "Mean": values["mean"], + "Standard Deviation": values["std_dev"], + "Median": values["median"], + "Q1": values["Q1"], + "Q3": values["Q3"], + } + all_results[file_name].append(result) + + for file_name, results in all_results.items(): + df_results = pd.DataFrame(results) + df_results.to_csv( + os.path.join(input_folder, f"{file_name}_summary.csv"), index=False + ) + + +def draw_boxes_on_image(img_with_boxes, bboxes): + for i, bbox in enumerate(bboxes): + unpacked_bbox = bbox[0] + bbox_dict = ast.literal_eval(unpacked_bbox.replace("'", '"')) + xmin, ymin = int(bbox_dict["xmin"]), int(bbox_dict["ymin"]) + width, height = bbox_dict["width"], bbox_dict["height"] + xmax = int(xmin + width) + ymax = int(ymin + height) + cv2.rectangle( + img_with_boxes, (xmin, ymin), (xmax, ymax), (0, 255, 0), 2 + ) + + # Add text to the bbox + text = f"GTruth bbox {i+1}" + font_scale = max( + 0.5, min(width, height) / 200.0 + ) # Ensure minimum font scale of 0.5 + font_scale = min(font_scale, 1.3) # Ensure maximum font scale of 1.3 + + # Computes text height + (text_width, text_height), baseline = cv2.getTextSize( + text, cv2.FONT_HERSHEY_SIMPLEX, font_scale, 2 + ) + + # Make sure not to draw text outside of image + if ymin + height + text_height + 2 < img_with_boxes.shape[0]: + y = ymin + height + text_height + 2 # put below box + else: + y = ymin - text_height - 2 # put above box + + y = int(y) + + cv2.putText( + img_with_boxes, + text, + (xmin + 5, y), + cv2.FONT_HERSHEY_SIMPLEX, + font_scale, + (36, 255, 12), + 2, + ) + return img_with_boxes + + +def draw_largest_component_bbox_on_image(img_with_boxes, x, y, w, h): + cv2.rectangle(img_with_boxes, (x, y), (x + w, y + h), (0, 0, 255), 2) + + # Add text to the bbox + text = "Detected Area" + font_scale = max(0.5, min(w, h) / 200.0) # Ensure minimum font scale of 0.5 + font_scale = min(font_scale, 1.3) # Ensure maximum font scale of 1.3 + + # Computes text height + (text_width, text_height), baseline = cv2.getTextSize( + text, cv2.FONT_HERSHEY_SIMPLEX, font_scale, 2 + ) + + # Make sure not to draw text outside bottom of image + if y + h + text_height + 2 < img_with_boxes.shape[0]: + y = y + h + text_height + 2 # put below box + else: + y = y - text_height - 2 # put above box + + y = int(y) + + cv2.putText( + img_with_boxes, + text, + (x + 5, y), + cv2.FONT_HERSHEY_SIMPLEX, + font_scale, + (0, 0, 255), + 2, + ) + return img_with_boxes + + +def show_cam_on_image( + img: np.ndarray, + thresholded_mask: np.ndarray, + use_rgb: bool = False, + colormap: int = cv2.COLORMAP_JET, + image_weight: float = 0.5, +) -> np.ndarray: + """This function overlays the cam mask on the image as an heatmap. By + default the heatmap is in BGR format. + + :param img: The base image in RGB or BGR format. + :param mask: The cam mask. + :param threshold: Only apply heatmap to areas where cam is above this threshold. + :param use_rgb: Whether to use an RGB or BGR heatmap, this should be set to True if 'img' is in RGB format. + :param colormap: The OpenCV colormap to be used. + :param image_weight: The final result is image_weight * img + (1-image_weight) * mask. + :returns: The default image with the cam overlay. + + This is a slightly modified version of the show_cam_on_image implementation in: + https://github.com/jacobgil/pytorch-grad-cam + """ + + if img.shape[:2] != thresholded_mask.shape: + raise ValueError( + "The shape of the mask should be the same as the shape of the image." + ) + + heatmap = cv2.applyColorMap(np.uint8(255 * thresholded_mask), colormap) + if use_rgb: + heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB) + heatmap = np.float32(heatmap) / 255 + + if np.max(img) > 1: + raise Exception("The input image should np.float32 in the range [0, 1]") + + if image_weight < 0 or image_weight > 1: + raise Exception( + f"image_weight should be in the range [0, 1].\ + Got: {image_weight}" + ) + + # For pixels where the mask is zero, + # the original image pixels are being used without a mask. + cam = np.where( + thresholded_mask[..., np.newaxis] == 0, + img, + (1 - image_weight) * heatmap + image_weight * img, + ) + cam = cam / np.max(cam) + return np.uint8(255 * cam) + + +def visualize_road_scores( + visualization, MoRF_score, LeRF_score, combined_score, name, percentiles +): + visualization = cv2.putText( + visualization, + name, + (10, 20), + cv2.FONT_HERSHEY_SIMPLEX, + 0.4, + (255, 255, 255), + 1, + cv2.LINE_AA, + ) + visualization = cv2.putText( + visualization, + f"Percentiles: {percentiles}", + (10, 40), + cv2.FONT_HERSHEY_SIMPLEX, + 0.4, + (255, 255, 255), + 1, + cv2.LINE_AA, + ) + visualization = cv2.putText( + visualization, + "Remove and Debias", + (10, 55), + cv2.FONT_HERSHEY_SIMPLEX, + 0.4, + (255, 255, 255), + 1, + cv2.LINE_AA, + ) + visualization = cv2.putText( + visualization, + f"MoRF score: {MoRF_score:.5f}", + (10, 70), + cv2.FONT_HERSHEY_SIMPLEX, + 0.4, + (255, 255, 255), + 1, + cv2.LINE_AA, + ) + visualization = cv2.putText( + visualization, + f"LeRF score: {LeRF_score:.5f}", + (10, 85), + cv2.FONT_HERSHEY_SIMPLEX, + 0.4, + (255, 255, 255), + 1, + cv2.LINE_AA, + ) + visualization = cv2.putText( + visualization, + f"(LeRF-MoRF)/2: {combined_score:.5f}", + (10, 100), + cv2.FONT_HERSHEY_SIMPLEX, + 0.4, + (255, 255, 255), + 1, + cv2.LINE_AA, + ) + return visualization diff --git a/tests/data/test_vis_metrics.csv b/tests/data/test_vis_metrics.csv new file mode 100644 index 0000000000000000000000000000000000000000..1a144f21b6f21693d56be19ef301c240f30398e3 --- /dev/null +++ b/tests/data/test_vis_metrics.csv @@ -0,0 +1,6 @@ +Image,MoRF,LeRF,Combined Score ((LeRF-MoRF) / 2),IoU,IoDA,propEnergy,ASF +tb0004.png,1,2,3,4,5,6,7 +tb0006.png,2,3,4,5,6,7,8 +tb0009.png,1,2,3,4,5,6,7 +tb0014.png,2,3,4,5,6,7,8 +tb0015.png,1,2,3,4,5,6,7 diff --git a/tests/data/test_visualization_images/indirect-model/tbx11k/ablationcam/targeted_class/test/tb0004.png b/tests/data/test_visualization_images/indirect-model/tbx11k/ablationcam/targeted_class/test/tb0004.png new file mode 100644 index 0000000000000000000000000000000000000000..1fc4b4d4293fa401d1bc4bb3e99bcc0083653965 Binary files /dev/null and b/tests/data/test_visualization_images/indirect-model/tbx11k/ablationcam/targeted_class/test/tb0004.png differ diff --git a/tests/data/test_visualization_images/indirect-model/tbx11k/ablationcam/targeted_class/train/tb0005.png b/tests/data/test_visualization_images/indirect-model/tbx11k/ablationcam/targeted_class/train/tb0005.png new file mode 100644 index 0000000000000000000000000000000000000000..1cdd8a5fc1959009da23a9b5e25c4cf09f344532 Binary files /dev/null and b/tests/data/test_visualization_images/indirect-model/tbx11k/ablationcam/targeted_class/train/tb0005.png differ diff --git a/tests/data/test_visualization_images/indirect-model/tbx11k/fullgrad/targeted_class/test/tb0004.png b/tests/data/test_visualization_images/indirect-model/tbx11k/fullgrad/targeted_class/test/tb0004.png new file mode 100644 index 0000000000000000000000000000000000000000..92cb8bc6e0c454cc38d26f61d1c31420f14985eb Binary files /dev/null and b/tests/data/test_visualization_images/indirect-model/tbx11k/fullgrad/targeted_class/test/tb0004.png differ diff --git a/tests/data/test_visualization_images/indirect-model/tbx11k/fullgrad/targeted_class/train/tb0005.png b/tests/data/test_visualization_images/indirect-model/tbx11k/fullgrad/targeted_class/train/tb0005.png new file mode 100644 index 0000000000000000000000000000000000000000..3a6aab0ebed12caca31920bde6a1bb1ae1cad7d0 Binary files /dev/null and b/tests/data/test_visualization_images/indirect-model/tbx11k/fullgrad/targeted_class/train/tb0005.png differ diff --git a/tests/data/test_visualization_images/indirect-model/tbx11k/gradcam/targeted_class/test/tb0004.png b/tests/data/test_visualization_images/indirect-model/tbx11k/gradcam/targeted_class/test/tb0004.png new file mode 100644 index 0000000000000000000000000000000000000000..c9cc1d12680c583a636aa60c3c28f8dd71d557ad Binary files /dev/null and b/tests/data/test_visualization_images/indirect-model/tbx11k/gradcam/targeted_class/test/tb0004.png differ diff --git a/tests/data/test_visualization_images/indirect-model/tbx11k/gradcam/targeted_class/train/tb0005.png b/tests/data/test_visualization_images/indirect-model/tbx11k/gradcam/targeted_class/train/tb0005.png new file mode 100644 index 0000000000000000000000000000000000000000..80fe19a9fb3a3700b787a83a2446083de6783c47 Binary files /dev/null and b/tests/data/test_visualization_images/indirect-model/tbx11k/gradcam/targeted_class/train/tb0005.png differ diff --git a/tests/data/test_visualization_images/indirect-model/tbx11k/randomcam/targeted_class/test/tb0004.png b/tests/data/test_visualization_images/indirect-model/tbx11k/randomcam/targeted_class/test/tb0004.png new file mode 100644 index 0000000000000000000000000000000000000000..abd1639622163c5a0d5ef8d32ff557ef157f63f0 Binary files /dev/null and b/tests/data/test_visualization_images/indirect-model/tbx11k/randomcam/targeted_class/test/tb0004.png differ diff --git a/tests/data/test_visualization_images/indirect-model/tbx11k/randomcam/targeted_class/train/tb0005.png b/tests/data/test_visualization_images/indirect-model/tbx11k/randomcam/targeted_class/train/tb0005.png new file mode 100644 index 0000000000000000000000000000000000000000..6fae53802b4640edaeb106deab6252920b860abd Binary files /dev/null and b/tests/data/test_visualization_images/indirect-model/tbx11k/randomcam/targeted_class/train/tb0005.png differ diff --git a/tests/data/test_visualization_images/indirect-model/tbx11k/scorecam/targeted_class/test/tb0004.png b/tests/data/test_visualization_images/indirect-model/tbx11k/scorecam/targeted_class/test/tb0004.png new file mode 100644 index 0000000000000000000000000000000000000000..4819323f21e673c959edff58e4e34e1b82507169 Binary files /dev/null and b/tests/data/test_visualization_images/indirect-model/tbx11k/scorecam/targeted_class/test/tb0004.png differ diff --git a/tests/data/test_visualization_images/indirect-model/tbx11k/scorecam/targeted_class/train/tb0005.png b/tests/data/test_visualization_images/indirect-model/tbx11k/scorecam/targeted_class/train/tb0005.png new file mode 100644 index 0000000000000000000000000000000000000000..8c10e34985fd278b5090a63033ae1fcb0d4c3903 Binary files /dev/null and b/tests/data/test_visualization_images/indirect-model/tbx11k/scorecam/targeted_class/train/tb0005.png differ diff --git a/tests/test_11k.py b/tests/test_11k.py new file mode 100644 index 0000000000000000000000000000000000000000..2bba74dcfbfb056987a9c5ccb26f7280e0938248 --- /dev/null +++ b/tests/test_11k.py @@ -0,0 +1,208 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Tests for TBX11K simplified dataset split 1.""" + +import pytest + + +def test_protocol_consistency(): + from ptbench.data.tbx11k_simplified import dataset + + # Default protocol + subset = dataset.subsets("default") + assert len(subset) == 3 + + assert "train" in subset + assert len(subset["train"]) == 2767 + for s in subset["train"]: + assert s.key.startswith("images/") + + assert "validation" in subset + assert len(subset["validation"]) == 706 + for s in subset["validation"]: + assert s.key.startswith("images/") + + assert "test" in subset + assert len(subset["test"]) == 957 + for s in subset["test"]: + assert s.key.startswith("images/") + + # Check labels + for s in subset["train"]: + assert s.label in [0.0, 1.0] + + for s in subset["validation"]: + assert s.label in [0.0, 1.0] + + for s in subset["test"]: + assert s.label in [0.0, 1.0] + + # Cross-validation fold 0-9 + for f in range(10): + subset = dataset.subsets("fold_" + str(f)) + assert len(subset) == 3 + + assert "train" in subset + assert len(subset["train"]) == 3177 + for s in subset["train"]: + assert s.key.startswith("images/") + + assert "validation" in subset + assert len(subset["validation"]) == 810 + for s in subset["validation"]: + assert s.key.startswith("images/") + + assert "test" in subset + assert len(subset["test"]) == 443 + for s in subset["test"]: + assert s.key.startswith("images/") + + # Check labels + for s in subset["train"]: + assert s.label in [0.0, 1.0] + + for s in subset["validation"]: + assert s.label in [0.0, 1.0] + + for s in subset["test"]: + assert s.label in [0.0, 1.0] + + +def test_protocol_consistency_bbox(): + from ptbench.data.tbx11k_simplified import dataset_with_bboxes + + # Default protocol + subset = dataset_with_bboxes["dataset"].subsets("default") + assert len(subset) == 3 + + assert "train" in subset + assert len(subset["train"]) == 2767 + for s in subset["train"]: + assert s.key.startswith("images/") + + assert "validation" in subset + assert len(subset["validation"]) == 706 + for s in subset["validation"]: + assert s.key.startswith("images/") + + assert "test" in subset + assert len(subset["test"]) == 957 + for s in subset["test"]: + assert s.key.startswith("images/") + + # Check labels + for s in subset["train"]: + assert s.label in [0.0, 1.0] + + for s in subset["validation"]: + assert s.label in [0.0, 1.0] + + for s in subset["test"]: + assert s.label in [0.0, 1.0] + + # Check bounding boxes + for s in subset["train"]: + assert s.bboxes == "none" or s.bboxes[0].startswith("{'xmin':") + + # Cross-validation fold 0-9 + for f in range(10): + subset = dataset_with_bboxes["dataset"].subsets("fold_" + str(f)) + assert len(subset) == 3 + + assert "train" in subset + assert len(subset["train"]) == 3177 + for s in subset["train"]: + assert s.key.startswith("images/") + + assert "validation" in subset + assert len(subset["validation"]) == 810 + for s in subset["validation"]: + assert s.key.startswith("images/") + + assert "test" in subset + assert len(subset["test"]) == 443 + for s in subset["test"]: + assert s.key.startswith("images/") + + # Check labels + for s in subset["train"]: + assert s.label in [0.0, 1.0] + + for s in subset["validation"]: + assert s.label in [0.0, 1.0] + + for s in subset["test"]: + assert s.label in [0.0, 1.0] + + # Check bounding boxes + for s in subset["train"]: + assert s.bboxes == "none" or s.bboxes[0].startswith("{'xmin':") + + +@pytest.mark.skip_if_rc_var_not_set("datadir.tbx11k_simplified") +def test_loading(): + from ptbench.data.tbx11k_simplified import dataset + + def _check_sample(s): + data = s.data + assert isinstance(data, dict) + assert len(data) == 2 + + assert "data" in data + assert data["data"].size == (512, 512) + + assert data["data"].mode == "L" # Check colors + + assert "label" in data + assert data["label"] in [0, 1] # Check labels + + limit = 30 # use this to limit testing to first images only, else None + + subset = dataset.subsets("default") + for s in subset["train"][:limit]: + _check_sample(s) + + +@pytest.mark.skip_if_rc_var_not_set("datadir.tbx11k_simplified") +def test_loading_bbox(): + from ptbench.data.tbx11k_simplified import dataset_with_bboxes + + def _check_sample(s): + data = s.data + assert isinstance(data, dict) + assert len(data) == 3 + + assert "data" in data + assert data["data"].size == (512, 512) + + assert data["data"].mode == "L" # Check colors + + assert "label" in data + assert data["label"] in [0, 1] # Check labels + + assert "bboxes" in data + assert data["bboxes"] == "none" or data["bboxes"][0].startswith( + "{'xmin':" + ) + + limit = 30 # use this to limit testing to first images only, else None + + subset = dataset_with_bboxes["dataset"].subsets("default") + for s in subset["train"][:limit]: + _check_sample(s) + + +@pytest.mark.skip_if_rc_var_not_set("datadir.tbx11k_simplified") +def test_check(): + from ptbench.data.tbx11k_simplified import dataset + + assert dataset.check() == 0 + + +@pytest.mark.skip_if_rc_var_not_set("datadir.tbx11k_simplified") +def test_check_bbox(): + from ptbench.data.tbx11k_simplified import dataset_with_bboxes + + assert dataset_with_bboxes["dataset"].check() == 0 diff --git a/tests/test_11k_RS.py b/tests/test_11k_RS.py new file mode 100644 index 0000000000000000000000000000000000000000..601bbc4628ea752f3ad52b78cedecd64a4b215dc --- /dev/null +++ b/tests/test_11k_RS.py @@ -0,0 +1,88 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Tests for Extended TBX11K simplified dataset split 1.""" + +import pytest + + +def test_protocol_consistency(): + from ptbench.data.tbx11k_simplified_RS import dataset + + # Default protocol + subset = dataset.subsets("default") + assert len(subset) == 3 + + assert "train" in subset + assert len(subset["train"]) == 2767 + + assert "validation" in subset + assert len(subset["validation"]) == 706 + + assert "test" in subset + assert len(subset["test"]) == 957 + for s in subset["test"]: + assert s.key.startswith("images/") + + # Check labels + for s in subset["train"]: + assert s.label in [0.0, 1.0] + + for s in subset["validation"]: + assert s.label in [0.0, 1.0] + + for s in subset["test"]: + assert s.label in [0.0, 1.0] + + # Cross-validation fold 0-9 + for f in range(10): + subset = dataset.subsets("fold_" + str(f)) + assert len(subset) == 3 + + assert "train" in subset + assert len(subset["train"]) == 3177 + for s in subset["train"]: + assert s.key.startswith("images/") + + assert "validation" in subset + assert len(subset["validation"]) == 810 + for s in subset["validation"]: + assert s.key.startswith("images/") + + assert "test" in subset + assert len(subset["test"]) == 443 + for s in subset["test"]: + assert s.key.startswith("images/") + + # Check labels + for s in subset["train"]: + assert s.label in [0.0, 1.0] + + for s in subset["validation"]: + assert s.label in [0.0, 1.0] + + for s in subset["test"]: + assert s.label in [0.0, 1.0] + + +@pytest.mark.skip_if_rc_var_not_set("datadir.tbx11k_simplified") +def test_loading(): + from ptbench.data.tbx11k_simplified_RS import dataset + + def _check_sample(s): + data = s.data + assert isinstance(data, dict) + assert len(data) == 2 + + assert "data" in data + assert len(data["data"]) == 14 # Check radiological signs + + assert "label" in data + assert data["label"] in [0, 1] # Check labels + + limit = 30 # use this to limit testing to first images only, else None + + subset = dataset.subsets("default") + for s in subset["train"][:limit]: + _check_sample(s) diff --git a/tests/test_11k_v2.py b/tests/test_11k_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..12662886ed4eea1a2fa654c80b9666c53e5af515 --- /dev/null +++ b/tests/test_11k_v2.py @@ -0,0 +1,270 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Tests for TBX11K simplified dataset split 2.""" + +import pytest + + +def test_protocol_consistency(): + from ptbench.data.tbx11k_simplified_v2 import dataset + + # Default protocol + subset = dataset.subsets("default") + assert len(subset) == 3 + + assert "train" in subset + assert len(subset["train"]) == 5241 + for s in subset["train"]: + assert s.key.startswith("images/") + + assert "validation" in subset + assert len(subset["validation"]) == 1335 + for s in subset["validation"]: + assert s.key.startswith("images/") + + assert "test" in subset + assert len(subset["test"]) == 1793 + for s in subset["test"]: + assert s.key.startswith("images/") + + # Check labels + for s in subset["train"]: + assert s.label in [0.0, 1.0] + + for s in subset["validation"]: + assert s.label in [0.0, 1.0] + + for s in subset["test"]: + assert s.label in [0.0, 1.0] + + # Cross-validation fold 0-8 + for f in range(9): + subset = dataset.subsets("fold_" + str(f)) + assert len(subset) == 3 + + assert "train" in subset + assert len(subset["train"]) == 6003 + for s in subset["train"]: + assert s.key.startswith("images/") + + assert "validation" in subset + assert len(subset["validation"]) == 1529 + for s in subset["validation"]: + assert s.key.startswith("images/") + + assert "test" in subset + assert len(subset["test"]) == 837 + for s in subset["test"]: + assert s.key.startswith("images/") + + # Check labels + for s in subset["train"]: + assert s.label in [0.0, 1.0] + + for s in subset["validation"]: + assert s.label in [0.0, 1.0] + + for s in subset["test"]: + assert s.label in [0.0, 1.0] + + # Cross-validation fold 9 + subset = dataset.subsets("fold_9") + assert len(subset) == 3 + + assert "train" in subset + assert len(subset["train"]) == 6003 + for s in subset["train"]: + assert s.key.startswith("images/") + + assert "validation" in subset + assert len(subset["validation"]) == 1530 + for s in subset["validation"]: + assert s.key.startswith("images/") + + assert "test" in subset + assert len(subset["test"]) == 836 + for s in subset["test"]: + assert s.key.startswith("images/") + + # Check labels + for s in subset["train"]: + assert s.label in [0.0, 1.0] + + for s in subset["validation"]: + assert s.label in [0.0, 1.0] + + for s in subset["test"]: + assert s.label in [0.0, 1.0] + + +def test_protocol_consistency_bbox(): + from ptbench.data.tbx11k_simplified_v2 import dataset_with_bboxes + + # Default protocol + subset = dataset_with_bboxes.subsets("default") + assert len(subset) == 3 + + assert "train" in subset + assert len(subset["train"]) == 5241 + for s in subset["train"]: + assert s.key.startswith("images/") + + assert "validation" in subset + assert len(subset["validation"]) == 1335 + for s in subset["validation"]: + assert s.key.startswith("images/") + + assert "test" in subset + assert len(subset["test"]) == 1793 + for s in subset["test"]: + assert s.key.startswith("images/") + + # Check labels + for s in subset["train"]: + assert s.label in [0.0, 1.0] + + for s in subset["validation"]: + assert s.label in [0.0, 1.0] + + for s in subset["test"]: + assert s.label in [0.0, 1.0] + + # Check bounding boxes + for s in subset["train"]: + assert s.bboxes == "none" or s.bboxes[0].startswith("{'xmin':") + + # Cross-validation fold 0-8 + for f in range(9): + subset = dataset_with_bboxes.subsets("fold_" + str(f)) + assert len(subset) == 3 + + assert "train" in subset + assert len(subset["train"]) == 6003 + for s in subset["train"]: + assert s.key.startswith("images/") + + assert "validation" in subset + assert len(subset["validation"]) == 1529 + for s in subset["validation"]: + assert s.key.startswith("images/") + + assert "test" in subset + assert len(subset["test"]) == 837 + for s in subset["test"]: + assert s.key.startswith("images/") + + # Check labels + for s in subset["train"]: + assert s.label in [0.0, 1.0] + + for s in subset["validation"]: + assert s.label in [0.0, 1.0] + + for s in subset["test"]: + assert s.label in [0.0, 1.0] + + # Check bounding boxes + for s in subset["train"]: + assert s.bboxes == "none" or s.bboxes[0].startswith("{'xmin':") + + # Cross-validation fold 9 + subset = dataset_with_bboxes.subsets("fold_9") + assert len(subset) == 3 + + assert "train" in subset + assert len(subset["train"]) == 6003 + for s in subset["train"]: + assert s.key.startswith("images/") + + assert "validation" in subset + assert len(subset["validation"]) == 1530 + for s in subset["validation"]: + assert s.key.startswith("images/") + + assert "test" in subset + assert len(subset["test"]) == 836 + for s in subset["test"]: + assert s.key.startswith("images/") + + # Check labels + for s in subset["train"]: + assert s.label in [0.0, 1.0] + + for s in subset["validation"]: + assert s.label in [0.0, 1.0] + + for s in subset["test"]: + assert s.label in [0.0, 1.0] + + # Check bounding boxes + for s in subset["train"]: + assert s.bboxes == "none" or s.bboxes[0].startswith("{'xmin':") + + +@pytest.mark.skip_if_rc_var_not_set("datadir.tbx11k_simplified_v2") +def test_loading(): + from ptbench.data.tbx11k_simplified_v2 import dataset + + def _check_sample(s): + data = s.data + assert isinstance(data, dict) + assert len(data) == 2 + + assert "data" in data + assert data["data"].size == (512, 512) + + assert data["data"].mode == "L" # Check colors + + assert "label" in data + assert data["label"] in [0, 1] # Check labels + + limit = 30 # use this to limit testing to first images only, else None + + subset = dataset.subsets("default") + for s in subset["train"][:limit]: + _check_sample(s) + + +@pytest.mark.skip_if_rc_var_not_set("datadir.tbx11k_simplified_v2") +def test_loading_bbox(): + from ptbench.data.tbx11k_simplified_v2 import dataset_with_bboxes + + def _check_sample(s): + data = s.data + assert isinstance(data, dict) + assert len(data) == 3 + + assert "data" in data + assert data["data"].size == (512, 512) + + assert data["data"].mode == "L" # Check colors + + assert "label" in data + assert data["label"] in [0, 1] # Check labels + + assert "bboxes" in data + assert data["bboxes"] == "none" or data["bboxes"][0].startswith( + "{'xmin':" + ) + + limit = 30 # use this to limit testing to first images only, else None + + subset = dataset_with_bboxes.subsets("default") + for s in subset["train"][:limit]: + _check_sample(s) + + +@pytest.mark.skip_if_rc_var_not_set("datadir.tbx11k_simplified_v2") +def test_check(): + from ptbench.data.tbx11k_simplified_v2 import dataset + + assert dataset.check() == 0 + + +@pytest.mark.skip_if_rc_var_not_set("datadir.tbx11k_simplified_v2") +def test_check_bbox(): + from ptbench.data.tbx11k_simplified_v2 import dataset_with_bboxes + + assert dataset_with_bboxes.check() == 0 diff --git a/tests/test_11k_v2_RS.py b/tests/test_11k_v2_RS.py new file mode 100644 index 0000000000000000000000000000000000000000..c6ac2464324aee1aa45e185c13380e301a949597 --- /dev/null +++ b/tests/test_11k_v2_RS.py @@ -0,0 +1,117 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Tests for Extended TBX11K simplified dataset split 2.""" + +import pytest + + +def test_protocol_consistency(): + from ptbench.data.tbx11k_simplified_v2_RS import dataset + + # Default protocol + subset = dataset.subsets("default") + assert len(subset) == 3 + + assert "train" in subset + assert len(subset["train"]) == 5241 + + assert "validation" in subset + assert len(subset["validation"]) == 1335 + + assert "test" in subset + assert len(subset["test"]) == 1793 + for s in subset["test"]: + assert s.key.startswith("images/") + + # Check labels + for s in subset["train"]: + assert s.label in [0.0, 1.0] + + for s in subset["validation"]: + assert s.label in [0.0, 1.0] + + for s in subset["test"]: + assert s.label in [0.0, 1.0] + + # Cross-validation fold 0-8 + for f in range(9): + subset = dataset.subsets("fold_" + str(f)) + assert len(subset) == 3 + + assert "train" in subset + assert len(subset["train"]) == 6003 + for s in subset["train"]: + assert s.key.startswith("images/") + + assert "validation" in subset + assert len(subset["validation"]) == 1529 + for s in subset["validation"]: + assert s.key.startswith("images/") + + assert "test" in subset + assert len(subset["test"]) == 837 + for s in subset["test"]: + assert s.key.startswith("images/") + + # Check labels + for s in subset["train"]: + assert s.label in [0.0, 1.0] + + for s in subset["validation"]: + assert s.label in [0.0, 1.0] + + for s in subset["test"]: + assert s.label in [0.0, 1.0] + + # Cross-validation fold 9 + subset = dataset.subsets("fold_9") + assert len(subset) == 3 + + assert "train" in subset + assert len(subset["train"]) == 6003 + for s in subset["train"]: + assert s.key.startswith("images/") + + assert "validation" in subset + assert len(subset["validation"]) == 1530 + for s in subset["validation"]: + assert s.key.startswith("images/") + + assert "test" in subset + assert len(subset["test"]) == 836 + for s in subset["test"]: + assert s.key.startswith("images/") + + # Check labels + for s in subset["train"]: + assert s.label in [0.0, 1.0] + + for s in subset["validation"]: + assert s.label in [0.0, 1.0] + + for s in subset["test"]: + assert s.label in [0.0, 1.0] + + +@pytest.mark.skip_if_rc_var_not_set("datadir.tbx11k_simplified") +def test_loading(): + from ptbench.data.tbx11k_simplified_v2_RS import dataset + + def _check_sample(s): + data = s.data + assert isinstance(data, dict) + assert len(data) == 2 + + assert "data" in data + assert len(data["data"]) == 14 # Check radiological signs + + assert "label" in data + assert data["label"] in [0, 1] # Check labels + + limit = 30 # use this to limit testing to first images only, else None + + subset = dataset.subsets("default") + for s in subset["train"][:limit]: + _check_sample(s) diff --git a/tests/test_cam_utils.py b/tests/test_cam_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..328be392ba52d7453a1f17e9265e91b49a4c6673 --- /dev/null +++ b/tests/test_cam_utils.py @@ -0,0 +1,391 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Tests for the cam_utils script.""" + +import cv2 +import numpy as np +import pandas as pd +import pytest + +from ptbench.utils.cam_utils import ( + _calculate_stats_over_dataset, + calculate_metrics_avg_for_every_class, + draw_boxes_on_image, + draw_largest_component_bbox_on_image, + show_cam_on_image, + visualize_road_scores, +) + + +def test_calculate_stats_over_dataset(datadir): + # Sample CSV + metrics = _calculate_stats_over_dataset(datadir / "test_vis_metrics.csv") + + expected_metrics = { + "MoRF": { + "mean": 1.4, + "median": 1.0, + "std_dev": 0.548, + "Q1": 1.0, + "Q3": 2.0, + }, + "LeRF": { + "mean": 2.4, + "median": 2.0, + "std_dev": 0.548, + "Q1": 2.0, + "Q3": 3.0, + }, + "Combined Score ((LeRF-MoRF) / 2)": { + "mean": 3.4, + "median": 3.0, + "std_dev": 0.548, + "Q1": 3.0, + "Q3": 4.0, + }, + "IoU": { + "mean": 4.4, + "median": 4.0, + "std_dev": 0.548, + "Q1": 4.0, + "Q3": 5.0, + }, + "IoDA": { + "mean": 5.4, + "median": 5.0, + "std_dev": 0.548, + "Q1": 5.0, + "Q3": 6.0, + }, + "propEnergy": { + "mean": 6.4, + "median": 6.0, + "std_dev": 0.548, + "Q1": 6.0, + "Q3": 7.0, + }, + "ASF": { + "mean": 7.4, + "median": 7.0, + "std_dev": 0.548, + "Q1": 7.0, + "Q3": 8.0, + }, + } + for column, expected_values in expected_metrics.items(): + for key, expected_value in expected_values.items(): + assert metrics[column][key] == expected_value + + +def test_calculate_metrics_avg_for_every_class(temporary_basedir): + # Create a sample directory structure and CSV files + input_folder = temporary_basedir / "camutils" / "gradcam" + input_folder.mkdir(parents=True, exist_ok=True) + class1_dir = input_folder / "class1" + class1_dir.mkdir(parents=True, exist_ok=True) + class2_dir = input_folder / "class2" + class2_dir.mkdir(parents=True, exist_ok=True) + + data = { + "MoRF": [1, 2, 3], + "LeRF": [2, 4, 6], + "Combined Score ((LeRF-MoRF) / 2)": [1.5, 3, 4.5], + "IoU": [1, 2, 3], + "IoDA": [2, 4, 6], + "propEnergy": [1.5, 3, 4.5], + "ASF": [1, 2, 3], + } + df = pd.DataFrame(data) + df.to_csv(class1_dir / "file1.csv", index=False) + df.to_csv(class2_dir / "file1.csv", index=False) + df.to_csv(class1_dir / "file2.csv", index=False) + df.to_csv(class2_dir / "file2.csv", index=False) + + calculate_metrics_avg_for_every_class(input_folder) + + # Assert the existence of summary files + assert (input_folder / "file1_summary.csv").exists() + assert (input_folder / "file2_summary.csv").exists() + + # Assert the content of summary files + score_types = [ + "MoRF", + "LeRF", + "Combined Score ((LeRF-MoRF) / 2)", + "IoU", + "IoDA", + "propEnergy", + "ASF", + ] + expected_data = { + "Class Name": ["class1"] * len(score_types) + + ["class2"] * len(score_types), + "Score Type": score_types * 2, + "Mean": [ + 2.0, + 4.0, + 3.0, + 2.0, + 4.0, + 3.0, + 2.0, + 2.0, + 4.0, + 3.0, + 2.0, + 4.0, + 3.0, + 2.0, + ], + "Standard Deviation": [ + 1.0, + 2.0, + 1.5, + 1.0, + 2.0, + 1.5, + 1.0, + 1.0, + 2.0, + 1.5, + 1.0, + 2.0, + 1.5, + 1.0, + ], + "Median": [ + 2.0, + 4.0, + 3.0, + 2.0, + 4.0, + 3.0, + 2.0, + 2.0, + 4.0, + 3.0, + 2.0, + 4.0, + 3.0, + 2.0, + ], + "Q1": [ + 1.5, + 3.0, + 2.25, + 1.5, + 3.0, + 2.25, + 1.5, + 1.5, + 3.0, + 2.25, + 1.5, + 3.0, + 2.25, + 1.5, + ], + "Q3": [ + 2.5, + 5.0, + 3.75, + 2.5, + 5.0, + 3.75, + 2.5, + 2.5, + 5.0, + 3.75, + 2.5, + 5.0, + 3.75, + 2.5, + ], + } + for file_name in ["file1", "file2"]: + df = pd.read_csv(input_folder / f"{file_name}_summary.csv") + for column, expected_values in expected_data.items(): + assert list(df[column]) == expected_values + + +def test_draw_boxes_on_image(): + # Create a sample image and bounding boxes + img_with_boxes = np.zeros((200, 200, 3), dtype=np.uint8) + img_with_boxes2 = np.zeros((256, 256, 3), dtype=np.uint8) + img_with_boxes3 = np.zeros((512, 512, 3), dtype=np.uint8) + bboxes = [ + ['{"xmin": 10, "ymin": 10, "width": 20, "height": 20}'], + ['{"xmin": 30, "ymin": 30, "width": 40, "height": 40}'], + ] + + result = draw_boxes_on_image(img_with_boxes, bboxes) + result2 = draw_boxes_on_image(img_with_boxes2, bboxes) + result3 = draw_boxes_on_image(img_with_boxes3, bboxes) + + # Verify bounding box by checking for green pixel + assert np.all(result[10, 10] == np.array([0, 255, 0])) + assert np.all(result[30, 30] == np.array([0, 255, 0])) + assert np.all(result2[10, 10] == np.array([0, 255, 0])) + assert np.all(result2[30, 30] == np.array([0, 255, 0])) + assert np.all(result3[10, 10] == np.array([0, 255, 0])) + assert np.all(result3[30, 30] == np.array([0, 255, 0])) + + # Verify that no bounding box is drawn outside expected area + assert np.all(result[0:10, 0:10] == np.array([0, 0, 0])) + assert np.all(result[160:, 160:] == np.array([0, 0, 0])) + assert np.all(result2[0:10, 0:10] == np.array([0, 0, 0])) + assert np.all(result2[160:, 160:] == np.array([0, 0, 0])) + assert np.all(result3[0:10, 0:10] == np.array([0, 0, 0])) + assert np.all(result3[160:, 160:] == np.array([0, 0, 0])) + + # Verifying some pixels in the expected text region aren't pure black (text is greenish) + assert not np.all(result[30:40, 10:30] == np.array([0, 0, 0])) + assert not np.all(result[70:80, 30:50] == np.array([0, 0, 0])) + assert not np.all(result2[30:40, 10:30] == np.array([0, 0, 0])) + assert not np.all(result2[70:80, 30:50] == np.array([0, 0, 0])) + assert not np.all(result3[30:40, 10:30] == np.array([0, 0, 0])) + assert not np.all(result3[70:80, 30:50] == np.array([0, 0, 0])) + + # Verify that it returns an image + assert isinstance(result, np.ndarray) + assert result.shape == img_with_boxes.shape + assert isinstance(result2, np.ndarray) + assert result2.shape == img_with_boxes2.shape + assert isinstance(result3, np.ndarray) + assert result3.shape == img_with_boxes3.shape + + +def test_draw_largest_component_bbox(): + # Create a sample image and bounding boxes + img_with_boxes = np.zeros((200, 200, 3), dtype=np.uint8) + img_with_boxes2 = np.zeros((256, 256, 3), dtype=np.uint8) + img_with_boxes3 = np.zeros((512, 512, 3), dtype=np.uint8) + x, y, w, h = 10, 10, 20, 20 + + result = draw_largest_component_bbox_on_image(img_with_boxes, x, y, w, h) + result2 = draw_largest_component_bbox_on_image(img_with_boxes2, x, y, w, h) + result3 = draw_largest_component_bbox_on_image(img_with_boxes3, x, y, w, h) + + # Verify that box is drawn as expected + assert np.all(result[10, 10] == np.array([0, 0, 255])) # Red Box Pixel + assert np.all(result2[10, 10] == np.array([0, 0, 255])) + assert np.all(result3[10, 10] == np.array([0, 0, 255])) + + # Check that no bounding box is outside expected area + assert np.all(result[0:10, 0:10] == np.array([0, 0, 0])) + assert np.all(result2[0:10, 0:10] == np.array([0, 0, 0])) + assert np.all(result3[0:10, 0:10] == np.array([0, 0, 0])) + assert np.any(result[40:, 40:] != np.array([0, 0, 0])) + assert np.any(result2[40:, 40:] != np.array([0, 0, 0])) + assert np.any(result3[40:, 40:] != np.array([0, 0, 0])) + + # Verify some pixels in the expected text region aren't pure black (text is reddish) + assert not np.all(result[31:40, 10:30] == np.array([0, 0, 0])) + assert not np.all(result2[31:40, 10:30] == np.array([0, 0, 0])) + assert not np.all(result3[31:40, 10:30] == np.array([0, 0, 0])) + + # Verify that it returns an image + assert isinstance(result, np.ndarray) + assert result.shape == img_with_boxes.shape + assert isinstance(result2, np.ndarray) + assert result2.shape == img_with_boxes2.shape + assert isinstance(result3, np.ndarray) + assert result3.shape == img_with_boxes3.shape + + +def test_show_cam_on_image(): + # Create a sample image and CAM + img = np.zeros((200, 200, 3), dtype=np.uint8) + img2 = np.zeros((256, 256, 3), dtype=np.uint8) + img3 = np.zeros((512, 512, 3), dtype=np.uint8) + cam = np.ones((200, 200), dtype=np.float32) + cam2 = np.ones((256, 256), dtype=np.float32) + cam3 = np.ones((512, 512), dtype=np.float32) + + result = show_cam_on_image(img, cam) + result2 = show_cam_on_image(img2, cam2) + result3 = show_cam_on_image(img3, cam3) + + assert isinstance(result, np.ndarray) + assert result.dtype == np.uint8 + assert result.shape == (200, 200, 3) + assert isinstance(result2, np.ndarray) + assert result2.dtype == np.uint8 + assert result2.shape == (256, 256, 3) + assert isinstance(result3, np.ndarray) + assert result3.dtype == np.uint8 + assert result3.shape == (512, 512, 3) + + assert np.all(result >= 0) + assert np.all(result <= 255) + + +def test_show_cam_on_image_use_rgb(): + img = np.ones((2, 2, 3), dtype=np.float32) + cam = np.ones((2, 2), dtype=np.float32) * 0.5 + + result = show_cam_on_image(img, cam, use_rgb=True) + + # Assert that result and img are not the same (that color mapping was applied) + assert not np.array_equal(result, np.uint8(img * 255)) + + +def test_show_cam_on_image_colormap(): + img = np.ones((2, 2, 3), dtype=np.float32) + cam = np.ones((2, 2), dtype=np.float32) * 0.5 + + result = show_cam_on_image(img, cam, colormap=cv2.COLORMAP_JET) + + # Assert that result and img are not the same (that color mapping was applied) + assert not np.array_equal(result, np.uint8(img * 255)) + + +def test_show_cam_on_image_raises_for_invalid_img(): + img = np.ones((200, 200, 3), dtype=np.float32) * 2 # values are above 1 + cam = np.ones((200, 200), dtype=np.float32) + + with pytest.raises( + Exception, + match=r"The input image should np.float32 in the range \[0, 1\].*", + ): + show_cam_on_image(img, cam) + + +def test_show_cam_on_image_raises_for_invalid_image_weight(): + img = np.zeros((200, 200, 3), dtype=np.float32) + cam = np.ones((200, 200), dtype=np.float32) + + with pytest.raises( + Exception, match=r"image_weight should be in the range \[0, 1\].*" + ): + show_cam_on_image(img, cam, image_weight=1.5) + + +def test_show_cam_on_image_raises_for_mismatched_shapes(): + img = np.ones((200, 200, 3), dtype=np.float32) + cam = np.ones((100, 100), dtype=np.float32) + + with pytest.raises( + ValueError, + match="The shape of the mask should be the same as the shape of the image.", + ): + show_cam_on_image(img, cam) + + +def test_visualize_road_scores(): + # Create a sample image and CAM + img = np.zeros((200, 200, 3), dtype=np.uint8) + MoRF_score = 0.12345 + LeRF_score = 0.23456 + combiend_score = 0.34567 + name = "Road" + percentiles = [25, 75] + + result = visualize_road_scores( + img, MoRF_score, LeRF_score, combiend_score, name, percentiles + ) + + assert isinstance(result, np.ndarray) + assert np.array_equal(result, img) diff --git a/tests/test_cli.py b/tests/test_cli.py index 5799e332172e32d27d03ff0e4f3efdc28135e421..6d1bb3d1976d0a2955480518a8ff1021584bffdc 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -148,12 +148,30 @@ def test_predict_help(): _check_help(predict) +def test_visualize_help(): + from ptbench.scripts.visualize import visualize + + _check_help(visualize) + + def test_evaluate_help(): from ptbench.scripts.evaluate import evaluate _check_help(evaluate) +def test_evaluatevis_help(): + from ptbench.scripts.evaluatevis import evaluatevis + + _check_help(evaluatevis) + + +def test_comparevis_help(): + from ptbench.scripts.comparevis import comparevis + + _check_help(comparevis) + + @pytest.mark.skip_if_rc_var_not_set("datadir.montgomery") def test_train_pasa_montgomery(temporary_basedir): from ptbench.scripts.train import train @@ -397,6 +415,77 @@ def test_evaluate_pasa_montgomery(temporary_basedir): ) +def test_comparevis(temporary_basedir, datadir): + from ptbench.scripts.comparevis import comparevis + + runner = CliRunner() + + input_dir = str( + datadir / "test_visualization_images" / "indirect-model" / "tbx11k" + ) + output_dir = str(temporary_basedir / "comparevis") + + result = runner.invoke( + comparevis, ["-vv", "-i", str(input_dir), "-o", str(output_dir)] + ) + + assert result.exit_code == 0 + + # Check that the output image was created + assert ( + temporary_basedir + / "comparevis" + / "targeted_class" + / "test" + / "tb0004.png" + ).exists() + assert ( + temporary_basedir + / "comparevis" + / "targeted_class" + / "train" + / "tb0005.png" + ).exists() + + +def test_evaluatevis(temporary_basedir): + import pandas as pd + + from ptbench.scripts.evaluatevis import evaluatevis + + runner = CliRunner() + + # Create a sample directory structure and CSV files + input_folder = temporary_basedir / "camutils_cli" / "gradcam" + input_folder.mkdir(parents=True, exist_ok=True) + class1_dir = input_folder / "class1" + class1_dir.mkdir(parents=True, exist_ok=True) + class2_dir = input_folder / "class2" + class2_dir.mkdir(parents=True, exist_ok=True) + + data = { + "MoRF": [1, 2, 3], + "LeRF": [2, 4, 6], + "Combined Score ((LeRF-MoRF) / 2)": [1.5, 3, 4.5], + "IoU": [1, 2, 3], + "IoDA": [2, 4, 6], + "propEnergy": [1.5, 3, 4.5], + "ASF": [1, 2, 3], + } + df = pd.DataFrame(data) + df.to_csv(class1_dir / "file1.csv", index=False) + df.to_csv(class2_dir / "file1.csv", index=False) + df.to_csv(class1_dir / "file2.csv", index=False) + df.to_csv(class2_dir / "file2.csv", index=False) + + result = runner.invoke(evaluatevis, ["-vv", "-i", str(input_folder)]) + + assert result.exit_code == 0 + + assert (input_folder / "file1_summary.csv").exists() + assert (input_folder / "file2_summary.csv").exists() + + # Not enough RAM available to do this test # @pytest.mark.skip_if_rc_var_not_set("datadir.montgomery") # def test_predict_densenetrs_montgomery(temporary_basedir, datadir): diff --git a/tests/test_comparevis.py b/tests/test_comparevis.py new file mode 100644 index 0000000000000000000000000000000000000000..4540ea504136bfff541d0c6c7c93f9a39b02a255 --- /dev/null +++ b/tests/test_comparevis.py @@ -0,0 +1,66 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Tests for the comparevis script.""" + +import shutil + +import pytest + +from click.testing import CliRunner + +from ptbench.scripts.comparevis import _sorting_rule, comparevis + + +@pytest.fixture() +def runner(): + return CliRunner() + + +@pytest.fixture +def comparevis_dirs(temporary_basedir): + input_dir = temporary_basedir / "comparevis_input_folder" + output_dir = temporary_basedir / "comparevis_output_folder" + if input_dir.exists(): + shutil.rmtree(input_dir) + if output_dir.exists(): + shutil.rmtree(output_dir) + input_dir.mkdir() + output_dir.mkdir() + return input_dir, output_dir + + +def test_sorting_rule(): + assert _sorting_rule("gradcam") == (0, "gradcam") + assert _sorting_rule("gradcam/scorecam") == (1, "scorecam") + assert _sorting_rule("fullgrad") == (2, "fullgrad") + assert _sorting_rule("randomcam") == (4, "randomcam") + assert _sorting_rule("test") == (3, "test") + + +def test_comparevis_output_subdirectory(runner, comparevis_dirs): + input_dir, output_dir = comparevis_dirs + sub_output_dir = input_dir / "subfolder" + sub_output_dir.mkdir() + + result = runner.invoke( + comparevis, ["-vv", "-i", str(input_dir), "-o", str(sub_output_dir)] + ) + assert isinstance(result.exception, AssertionError) + + +def test_comparevis_same_input_output(runner, comparevis_dirs): + input_dir, output_dir = comparevis_dirs + result = runner.invoke( + comparevis, ["-vv", "-i", str(input_dir), "-o", str(input_dir)] + ) + assert isinstance(result.exception, AssertionError) + + +def test_comparevis_no_subdirectories(runner, comparevis_dirs): + input_dir, output_dir = comparevis_dirs + result = runner.invoke( + comparevis, ["-vv", "-i", str(input_dir), "-o", str(output_dir)] + ) + assert isinstance(result.exception, ValueError) diff --git a/tests/test_visualizer.py b/tests/test_visualizer.py new file mode 100644 index 0000000000000000000000000000000000000000..1341df98740a91a2858e0c9098714a9a7b5c7828 --- /dev/null +++ b/tests/test_visualizer.py @@ -0,0 +1,171 @@ +import numpy as np +import pytest + +from ptbench.engine.visualizer import ( + _compute_avg_saliency_focus, + _compute_max_iou_and_ioda, + _compute_proportional_energy, + _compute_simultaneous_iou_and_ioda, + calculate_localization_metrics, +) + + +def test_compute_max_iou_and_ioda(): + detected_box = (10, 10, 100, 100) + gt_box_dict = {"xmin": 50, "ymin": 50, "width": 50, "height": 50} + gt_box_dict2 = {"xmin": 20, "ymin": 20, "width": 60, "height": 60} + gt_boxes = [[str(gt_box_dict)], [str(gt_box_dict2)]] + + iou, ioda = _compute_max_iou_and_ioda(detected_box, gt_boxes) + + expected_iou = 0.36 + expected_ioda = 0.36 + + assert iou == expected_iou + assert ioda == expected_ioda + + +def test_compute_max_iou_and_ioda_zero_detected_area(): + detected_box = (10, 10, 0, 0) + gt_box_dict = {"xmin": 50, "ymin": 50, "width": 50, "height": 50} + gt_boxes = [[str(gt_box_dict)]] + + iou, ioda = _compute_max_iou_and_ioda(detected_box, gt_boxes) + + # Should be zero as the detected box has no area + assert iou == 0 + assert ioda == 0 + + +def test_compute_max_iou_and_ioda_zero_gt_area(): + detected_box = (10, 10, 100, 100) + gt_box_dict = {"xmin": 50, "ymin": 50, "width": 0, "height": 0} + gt_boxes = [[str(gt_box_dict)]] + + iou, ioda = _compute_max_iou_and_ioda(detected_box, gt_boxes) + + # Should be zero as there is no ground truth box + assert iou == 0 + assert ioda == 0 + + +def test_compute_max_iou_and_ioda_zero_intersection(): + detected_box = (10, 10, 100, 100) + gt_box_dict = {"xmin": 0, "ymin": 0, "width": 5, "height": 5} + gt_boxes = [[str(gt_box_dict)]] + + iou, ioda = _compute_max_iou_and_ioda(detected_box, gt_boxes) + + assert iou == 0 + assert ioda == 0 + + +def test_compute_simultaneous_iou_and_ioda(): + detected_box = (10, 10, 100, 100) + gt_box_dict1 = {"xmin": 50, "ymin": 50, "width": 50, "height": 50} + gt_box_dict2 = {"xmin": 70, "ymin": 70, "width": 30, "height": 30} + gt_boxes = [[str(gt_box_dict1)], [str(gt_box_dict2)]] + + iou, ioda = _compute_simultaneous_iou_and_ioda(detected_box, gt_boxes) + + assert iou == 0.34 + assert ioda == 0.34 + + +def test_compute_avg_saliency_focus(): + grayscale_cams = [np.ones((200, 200))] + grayscale_cams2 = [np.full((512, 512), 0.5)] + grayscale_cams3 = [np.zeros((256, 256))] + grayscale_cams3[0][50:75, 50:100] = 1 + gt_box_dict = {"xmin": 50, "ymin": 50, "width": 50, "height": 50} + gt_boxes = [[str(gt_box_dict)]] + + avg_saliency_focus = _compute_avg_saliency_focus(gt_boxes, grayscale_cams) + avg_saliency_focus2 = _compute_avg_saliency_focus(gt_boxes, grayscale_cams2) + avg_saliency_focus3 = _compute_avg_saliency_focus(gt_boxes, grayscale_cams3) + + assert avg_saliency_focus == 1 + assert avg_saliency_focus2 == 0.5 + assert avg_saliency_focus3 == 0.5 + + +def test_compute_avg_saliency_focus_no_activations(): + grayscale_cams = [np.zeros((200, 200))] + gt_box_dict = {"xmin": 50, "ymin": 50, "width": 50, "height": 50} + gt_boxes = [[str(gt_box_dict)]] + + avg_saliency_focus = _compute_avg_saliency_focus(gt_boxes, grayscale_cams) + + assert avg_saliency_focus == 0 + + +def test_compute_avg_saliency_focus_zero_gt_area(): + grayscale_cams = [np.ones((200, 200))] + gt_box_dict = {"xmin": 50, "ymin": 50, "width": 0, "height": 0} + gt_boxes = [[str(gt_box_dict)]] + + avg_saliency_focus = _compute_avg_saliency_focus(gt_boxes, grayscale_cams) + + assert avg_saliency_focus == 0 + + +def test_compute_proportional_energy(): + grayscale_cams = [np.ones((200, 200))] + grayscale_cams2 = [np.full((512, 512), 0.5)] + grayscale_cams3 = [np.zeros((512, 512))] + grayscale_cams3[0][100:200, 100:200] = 1 + gt_box_dict = {"xmin": 50, "ymin": 50, "width": 100, "height": 100} + gt_boxes = [[str(gt_box_dict)]] + + proportional_energy = _compute_proportional_energy(gt_boxes, grayscale_cams) + proportional_energy2 = _compute_proportional_energy( + gt_boxes, grayscale_cams2 + ) + proportional_energy3 = _compute_proportional_energy( + gt_boxes, grayscale_cams3 + ) + + assert proportional_energy == 0.25 + assert proportional_energy2 == 0.03814697265625 + assert proportional_energy3 == 0.25 + + +def test_compute_proportional_energy_no_activations(): + grayscale_cams = [np.zeros((200, 200))] + gt_box_dict = {"xmin": 50, "ymin": 50, "width": 50, "height": 50} + gt_boxes = [[str(gt_box_dict)]] + + proportional_energy = _compute_proportional_energy(gt_boxes, grayscale_cams) + + assert proportional_energy == 0 + + +def test_compute_proportional_energy_no_gt_box(): + grayscale_cams = [np.ones((200, 200))] + gt_box_dict = {"xmin": 0, "ymin": 0, "width": 0, "height": 0} + gt_boxes = [[str(gt_box_dict)]] + + proportional_energy = _compute_proportional_energy(gt_boxes, grayscale_cams) + + assert proportional_energy == 0 + + +def test_calculate_localization_metrics(): + grayscale_cams = [np.zeros((200, 200))] + detected_box = (10, 10, 100, 100) + gt_box_dict = {"xmin": 50, "ymin": 50, "width": 50, "height": 50} + ground_truth_box = [[str(gt_box_dict)]] + + ( + iou, + ioda, + proportional_energy, + avg_saliency_focus, + ) = calculate_localization_metrics( + grayscale_cams, detected_box, ground_truth_box + ) + + assert iou == 0.25 + assert ioda == 0.25 + assert proportional_energy == 0 + assert avg_saliency_focus == 0