From c3bebb065806751adeeac2fc60843786aa453d82 Mon Sep 17 00:00:00 2001
From: "ogueler@idiap.ch" <ogueler@vws110.idiap.ch>
Date: Wed, 20 Sep 2023 01:01:39 +0200
Subject: [PATCH] added saliency map evaluation and visualisation for lightning
 models

---
 src/ptbench/engine/saliencymap_evaluator.py  | 41 +++++----
 src/ptbench/engine/visualizer.py             | 70 ++++++++-------
 src/ptbench/scripts/calculate_road.py        | 12 ++-
 src/ptbench/scripts/evaluate_saliencymaps.py | 64 +++++++-------
 src/ptbench/scripts/generate_saliencymaps.py | 10 +--
 src/ptbench/scripts/visualize.py             | 89 +++++++++++---------
 src/ptbench/utils/cam_utils.py               |  7 +-
 7 files changed, 150 insertions(+), 143 deletions(-)

diff --git a/src/ptbench/engine/saliencymap_evaluator.py b/src/ptbench/engine/saliencymap_evaluator.py
index 1285dc95..07d9dde8 100644
--- a/src/ptbench/engine/saliencymap_evaluator.py
+++ b/src/ptbench/engine/saliencymap_evaluator.py
@@ -2,7 +2,6 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
-import ast
 import logging
 import os
 
@@ -29,10 +28,8 @@ def _compute_max_iou_and_ioda(detected_box, gt_boxes):
     max_gt_area = 0
 
     for bbox in gt_boxes:
-        unpacked_bbox = bbox[0]
-        bbox_dict = ast.literal_eval(unpacked_bbox.replace("'", '"'))
-        xmin, ymin = int(bbox_dict["xmin"]), int(bbox_dict["ymin"])
-        width, height = int(bbox_dict["width"]), int(bbox_dict["height"])
+        xmin, ymin = int(bbox[1].item()), int(bbox[2].item())
+        width, height = int(bbox[3].item()), int(bbox[4].item())
 
         gt_area = width * height
 
@@ -78,10 +75,8 @@ def _compute_simultaneous_iou_and_ioda(detected_box, gt_boxes):
     total_gt_area = 0
 
     for bbox in gt_boxes:
-        unpacked_bbox = bbox[0]
-        bbox_dict = ast.literal_eval(unpacked_bbox.replace("'", '"'))
-        xmin, ymin = int(bbox_dict["xmin"]), int(bbox_dict["ymin"])
-        width, height = int(bbox_dict["width"]), int(bbox_dict["height"])
+        xmin, ymin = int(bbox[1].item()), int(bbox[2].item())
+        width, height = int(bbox[3].item()), int(bbox[4].item())
 
         gt_area = width * height
         total_gt_area += gt_area
@@ -112,10 +107,8 @@ def _compute_avg_saliency_focus(gt_boxes, saliency_map):
     # For each gt box, draw a binary mask
     # The binary_mask will be 1 where the gt boxes are located
     for bbox in gt_boxes:
-        unpacked_bbox = bbox[0]
-        bbox_dict = ast.literal_eval(unpacked_bbox.replace("'", '"'))
-        xmin, ymin = int(bbox_dict["xmin"]), int(bbox_dict["ymin"])
-        width, height = int(bbox_dict["width"]), int(bbox_dict["height"])
+        xmin, ymin = int(bbox[1].item()), int(bbox[2].item())
+        width, height = int(bbox[3].item()), int(bbox[4].item())
 
         binary_mask[ymin : ymin + height, xmin : xmin + width] = 1
 
@@ -143,10 +136,8 @@ def _compute_proportional_energy(gt_boxes, saliency_map):
     # For each gt box, draw a binary mask
     # The binary_mask will be 1 where the gt boxes are located
     for bbox in gt_boxes:
-        unpacked_bbox = bbox[0]
-        bbox_dict = ast.literal_eval(unpacked_bbox.replace("'", '"'))
-        xmin, ymin = int(bbox_dict["xmin"]), int(bbox_dict["ymin"])
-        width, height = int(bbox_dict["width"]), int(bbox_dict["height"])
+        xmin, ymin = int(bbox[1].item()), int(bbox[2].item())
+        width, height = int(bbox[3].item()), int(bbox[4].item())
 
         binary_mask[ymin : ymin + height, xmin : xmin + width] = 1
 
@@ -223,7 +214,7 @@ def process_target_class(
     # Write metrics to csv file
     csv_writer.writerow(
         [
-            names[0].split("/")[1],
+            names[0],
             iou,
             ioda,
             proportional_energy,
@@ -268,19 +259,25 @@ def run(
     """
     for samples in tqdm(data_loader, desc="batches", leave=False, disable=None):
         # Check if the sample has a bounding box entry
-        if len(samples) < 4:
+        if "radsign_bboxes" not in samples[1]:
             logger.warning(
                 "The dataset does not contain bounding box information. No localization metrics can be calculated."
             )
             return
         else:
             # TB negative labels are skipped (they don't have gt bboxes)
-            if samples[3][0] == "none":
+            if samples[1]["label"].item() == 0:
                 continue
 
-        names = samples[0]
+        names = samples[1]["name"]
 
-        gt_bboxes = samples[3]
+        gt_bboxes = samples[1]["radsign_bboxes"]
+
+        if not gt_bboxes:
+            logger.warning(
+                f'This sample does not have bounding box information. No localization metrics can be calculated. Sample "{names[0]}" is skipped.'
+            )
+            continue
 
         for target_class_name, csv_writer in csv_writers.items():
             saliency_map_path = os.path.join(
diff --git a/src/ptbench/engine/visualizer.py b/src/ptbench/engine/visualizer.py
index b734a547..b149f3ea 100644
--- a/src/ptbench/engine/visualizer.py
+++ b/src/ptbench/engine/visualizer.py
@@ -78,8 +78,8 @@ def run(
     ), "Output folder must not be the same as the input folder."
 
     # Check that the output folder is not a subdirectory of the input_folder
-    assert not output_folder.startswith(
-        input_folder
+    assert not str(output_folder).startswith(
+        str(input_folder)
     ), "Output folder must not be a subdirectory of the input folder."
 
     output_folder = os.path.abspath(output_folder)
@@ -89,6 +89,8 @@ def run(
 
     target_classes = _get_target_classes_from_directory(input_folder)
 
+    shown_warnings = set()
+
     if visualize_detected_bbox:
         # Load detected bounding box information
         localization_metric_dfs = {}
@@ -113,16 +115,19 @@ def run(
 
     for samples in tqdm(data_loader, desc="batches", leave=False, disable=None):
         includes_bboxes = True
-        if len(samples) < 4 or samples[2][0].item() == 0:
+        if samples[1]["label"].item() == 0 or "radsign_bboxes" not in samples[1]:
             includes_bboxes = False
+        else:
+            gt_bboxes = samples[1]["radsign_bboxes"]
+            if not gt_bboxes:
+                includes_bboxes = False
 
-        # if visualize_groundtruth and not includes_bboxes:
-        #    logger.warning(
-        #        "This sample does not have bounding box information. No ground truth bounding boxes can be visualized."
-        #    )
+        names = samples[1]["name"]
 
-        names = samples[0]
-        gt_bboxes = samples[3]
+        if visualize_groundtruth and not includes_bboxes and samples[1]["label"].item() == 1:
+           logger.warning(
+               f'Sample "{names[0]}" does not have bounding box information. No ground truth bounding boxes can be visualized.'
+           )
 
         for target_class in target_classes:
             input_base_path = (
@@ -136,7 +141,7 @@ def run(
             saliency_map_path = os.path.join(input_base_path, saliency_map_name)
 
             # Skip if the saliency map does not exist (e.g. if no saliency maps were generated for TB negative samples)
-            if not includes_bboxes and not os.path.exists(saliency_map_path):
+            if not os.path.exists(saliency_map_path):
                 continue
 
             saliency_map = np.load(saliency_map_path)
@@ -146,11 +151,10 @@ def run(
             img_with_boxes = original_img.copy()
 
             # Draw bounding boxes on the image
-            if includes_bboxes:
-                if gt_bboxes != "none" and visualize_groundtruth:
-                    img_with_boxes = draw_boxes_on_image(
-                        img_with_boxes, gt_bboxes
-                    )
+            if visualize_groundtruth and includes_bboxes:
+                img_with_boxes = draw_boxes_on_image(
+                    img_with_boxes, gt_bboxes
+                )
 
             # show_cam_on_image expects the img to be between [0, 1]
             rgb_img = img_with_boxes / 255.0
@@ -168,7 +172,7 @@ def run(
             if visualize_detected_bbox:
                 if target_class in localization_metric_dfs:
                     temp_df = localization_metric_dfs[target_class]
-                    temp_filename = names[0].split("/")[1]
+                    temp_filename = names[0]
 
                     row = temp_df[temp_df["Image"] == temp_filename]
 
@@ -186,15 +190,16 @@ def run(
                             f"Could not find entry for {temp_filename} in .csv file."
                         )
                 else:
-                    logger.warning(
-                        f"No localization metrics .csv file found for the {dataset_split_name} split for {target_class}."
-                    )
+                    error_message = f"No localization metrics .csv file found for the {dataset_split_name} split for {target_class}."
+                    if error_message not in shown_warnings:
+                        logger.warning(error_message)
+                        shown_warnings.add(error_message)
 
             # Visualize ROAD scores
             if road_path is not None:
                 if target_class in road_metric_dfs:
                     temp_df = road_metric_dfs[target_class]
-                    temp_filename = names[0].split("/")[1]
+                    temp_filename = names[0]
 
                     row = temp_df[temp_df["Image"] == temp_filename]
 
@@ -205,23 +210,24 @@ def run(
                             row["Combined Score ((LeRF-MoRF) / 2)"].values[0]
                         )
                         percentiles = str(row["Percentiles"].values[0])
+
+                        visualization = visualize_road_scores(
+                            visualization,
+                            MoRF_score,
+                            LeRF_score,
+                            combined_score,
+                            vis_directory_name,
+                            percentiles,
+                        )
                     else:
                         logger.warning(
                             f"Could not find entry for {temp_filename} in .csv file."
                         )
                 else:
-                    logger.warning(
-                        f"No ROAD metrics .csv file found for the {dataset_split_name} split for {target_class}."
-                    )
-
-                visualization = visualize_road_scores(
-                    visualization,
-                    MoRF_score,
-                    LeRF_score,
-                    combined_score,
-                    vis_directory_name,
-                    percentiles,
-                )
+                    error_message = f"No ROAD metrics .csv file found for the {dataset_split_name} split for {target_class}."
+                    if error_message not in shown_warnings:
+                        logger.warning(error_message)
+                        shown_warnings.add(error_message)
 
             # Save image
             output_file_path = os.path.join(output_base_path, names[0])
diff --git a/src/ptbench/scripts/calculate_road.py b/src/ptbench/scripts/calculate_road.py
index 41bbe3db..708e7254 100644
--- a/src/ptbench/scripts/calculate_road.py
+++ b/src/ptbench/scripts/calculate_road.py
@@ -194,28 +194,28 @@ def prepare_csv_writers(
 
        .. code:: sh
 
-          ptbench calculate-road -vv pasa tbx11k_simplified_bbox --device="cuda" --weight=path/to/model_final.pth --output-folder=path/to/output_folder
+          ptbench calculate-road -vv pasa tbx11k-v1-healthy-vs-atb --device="cuda" --weight=path/to/model_final.pth --output-folder=path/to/output_folder
 
 """,
 )
 @click.option(
     "--model",
     "-m",
-    help="A lightining module instance implementing the network to be trained.",
+    help="A lightining module instance implementing the network to be used for inference.",
     required=True,
     cls=ResourceOption,
 )
 @click.option(
     "--datamodule",
     "-d",
-    help="A lighting data module containing the training and validation sets.",
+    help="A lighting data module containing the training, validation and test sets.",
     required=True,
     cls=ResourceOption,
 )
 @click.option(
     "--output-folder",
     "-o",
-    help="Path where to store the visualizations (created if does not exist)",
+    help="Path where to store the ROAD scores (created if does not exist)",
     required=True,
     type=click.Path(
         file_okay=False,
@@ -352,11 +352,9 @@ def calculate_road(
             
             logger.info(f"Calculating ROAD scores for '{k}' set...")
 
-            data_loader = v
-
             run(
                 model,
-                data_loader,
+                data_loader=v,
                 output_folder=output_folder,
                 device=device,
                 cam=cam,
diff --git a/src/ptbench/scripts/evaluate_saliencymaps.py b/src/ptbench/scripts/evaluate_saliencymaps.py
index 991e9e5e..e9cc72a9 100644
--- a/src/ptbench/scripts/evaluate_saliencymaps.py
+++ b/src/ptbench/scripts/evaluate_saliencymaps.py
@@ -4,6 +4,7 @@
 
 import csv
 import os
+import pathlib
 
 import click
 
@@ -65,18 +66,21 @@ def prepare_csv_writers(input_folder, dataset_split):
 
        .. code:: sh
 
-          ptbench evaluate-saliencymaps -vv tbx11k_simplified_bbox_rgb --input-folder=parent_folder/gradcam/
+          ptbench evaluate-saliencymaps -vv pasa tbx11k-v1-healthy-vs-atb --input-folder=parent_folder/gradcam/
 
 """,
 )
 @click.option(
-    "--dataset",
+    "--model",
+    "-m",
+    help="A lightining module instance implementing the network to be used for applying the necessary data transformations.",
+    required=True,
+    cls=ResourceOption,
+)
+@click.option(
+    "--datamodule",
     "-d",
-    help="A torch.utils.data.dataset.Dataset instance implementing a dataset "
-    "to be used for generating visualizations, possibly including all pre-processing "
-    "pipelines required or, optionally, a dictionary mapping string keys to "
-    "torch.utils.data.dataset.Dataset instances.  All keys that do not start "
-    "with an underscore (_) will be processed.",
+    help="A lighting data module containing the training, validation and test sets.",
     required=True,
     cls=ResourceOption,
 )
@@ -85,13 +89,19 @@ def prepare_csv_writers(input_folder, dataset_split):
     "-i",
     help="Path to the folder containing the saliency maps for a specific visualization type.",
     required=True,
+        type=click.Path(
+        file_okay=False,
+        dir_okay=True,
+        writable=True,
+        path_type=pathlib.Path,
+    ),
     default="visualizations",
     cls=ResourceOption,
-    type=click.Path(),
 )
 @verbosity_option(logger=logger, cls=ResourceOption, expose_value=False)
 def evaluate_saliencymaps(
-    dataset,
+    model,
+    datamodule,
     input_folder,
     **_,
 ) -> None:
@@ -102,40 +112,30 @@ def evaluate_saliencymaps(
     Calculates them for each target class and split of the dataset.
     """
 
-    import torch
+    from .utils import save_sh_command
+    from ..engine.saliencymap_evaluator import run
 
-    from torch.utils.data import DataLoader
+    save_sh_command(input_folder / "command.sh")
 
-    from ..engine.saliencymap_evaluator import run
+    datamodule.set_chunk_size(1, 1)
+    datamodule.drop_incomplete_batch = False
+    #datamodule.cache_samples = cache_samples
+    #datamodule.parallel = parallel
+    datamodule.model_transforms = model.model_transforms
 
-    if "datadir" in dataset:
-        dataset = (
-            dataset["dataset"]
-            if isinstance(dataset["dataset"], dict)
-            else dict(test=dataset["dataset"])
-        )
-    else:
-        dataset = dataset if isinstance(dataset, dict) else dict(test=dataset)
+    datamodule.prepare_data()
+    datamodule.setup(stage="predict")
 
-    for k, v in dataset.items():
-        if k.startswith("_"):
-            logger.info(f"Skipping dataset '{k}' (not to be evaluated)")
-            continue
+    dataloaders = datamodule.predict_dataloader()
 
+    for k, v in dataloaders.items():
         csv_files, csv_writers = prepare_csv_writers(input_folder, k)
 
         logger.info(f"Calculating localization metrics for '{k}' set...")
 
-        data_loader = DataLoader(
-            dataset=v,
-            batch_size=1,
-            shuffle=False,
-            pin_memory=torch.cuda.is_available(),
-        )
-
         run(
             input_folder,
-            data_loader,
+            data_loader=v,
             dataset_split_name=k,
             csv_writers=csv_writers,
         )
diff --git a/src/ptbench/scripts/generate_saliencymaps.py b/src/ptbench/scripts/generate_saliencymaps.py
index 331e31db..411ac0d7 100644
--- a/src/ptbench/scripts/generate_saliencymaps.py
+++ b/src/ptbench/scripts/generate_saliencymaps.py
@@ -126,21 +126,21 @@ def create_cam(vis_type, model, target_layers, use_cuda):
 
        .. code:: sh
 
-          ptbench generate-saliencymaps -vv densenet tbx11k_simplified_bbox_rgb --device="cuda" --weight=path/to/model_final.pth --output-folder=path/to/visualizations
+          ptbench generate-saliencymaps -vv densenet tbx11k-v1-healthy-vs-atb --device="cuda" --weight=path/to/model_final.pth --output-folder=path/to/visualizations
 
 """,
 )
 @click.option(
     "--model",
     "-m",
-    help="A lightining module instance implementing the network to be trained.",
+    help="A lightining module instance implementing the network to be used for inference.",
     required=True,
     cls=ResourceOption,
 )
 @click.option(
     "--datamodule",
     "-d",
-    help="A lighting data module containing the training and validation sets.",
+    help="A lighting data module containing the training, validation and test sets.",
     required=True,
     cls=ResourceOption,
 )
@@ -273,11 +273,9 @@ def generate_saliencymaps(
         for k, v in dataloaders.items():
             logger.info(f"Generating saliency maps for '{k}' set...")
 
-            data_loader = v
-
             run(
                 model,
-                data_loader,
+                data_loader=v,
                 output_folder=output_folder,
                 dataset_split_name=k,
                 device=device,
diff --git a/src/ptbench/scripts/visualize.py b/src/ptbench/scripts/visualize.py
index b3922581..96bf4b04 100644
--- a/src/ptbench/scripts/visualize.py
+++ b/src/ptbench/scripts/visualize.py
@@ -2,6 +2,8 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
+import pathlib
+
 import click
 
 from clapper.click import ConfigCommand, ResourceOption, verbosity_option
@@ -20,18 +22,21 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
        .. code:: sh
 
-          ptbench visualize -vv tbx11k_simplified_bbox_rgb --input-folder=parent_folder/gradcam/ --road-path=parent_folder/gradcam --output-folder=path/to/visualizations
+          ptbench visualize -vv pasa tbx11k-v1-healthy-vs-atb --input-folder=parent_folder/gradcam/ --road-path=parent_folder/gradcam/ --output-folder=path/to/visualizations
 
 """,
 )
 @click.option(
-    "--dataset",
+    "--model",
+    "-m",
+    help="A lightining module instance implementing the network to be used for applying the necessary data transformations.",
+    required=True,
+    cls=ResourceOption,
+)
+@click.option(
+    "--datamodule",
     "-d",
-    help="A torch.utils.data.dataset.Dataset instance implementing a dataset "
-    "to be used for generating visualizations, possibly including all pre-processing "
-    "pipelines required or, optionally, a dictionary mapping string keys to "
-    "torch.utils.data.dataset.Dataset instances.  All keys that do not start "
-    "with an underscore (_) will be processed.",
+    help="A lighting data module containing the training, validation and test sets.",
     required=True,
     cls=ResourceOption,
 )
@@ -40,18 +45,28 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
     "-i",
     help="Path to the folder containing the saliency maps for a specific visualization type.",
     required=True,
+        type=click.Path(
+        file_okay=False,
+        dir_okay=True,
+        writable=True,
+        path_type=pathlib.Path,
+    ),
     default="visualizations",
     cls=ResourceOption,
-    type=click.Path(),
 )
 @click.option(
     "--output-folder",
     "-o",
-    help="Path where to store the visualizations (created if does not exist)",
+    help="Path where to store the ROAD scores (created if does not exist)",
     required=True,
+    type=click.Path(
+        file_okay=False,
+        dir_okay=True,
+        writable=True,
+        path_type=pathlib.Path,
+    ),
     default="visualizations",
     cls=ResourceOption,
-    type=click.Path(),
 )
 @click.option(
     "--visualize-groundtruth",
@@ -64,7 +79,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 @click.option(
     "--road-path",
     "-r",
-    help="If the path to the previously calculated road scores is provided, MoRF, LeRF, and combined ROAD scores will be visualized on each image.",
+    help="If the path to the previously calculated ROAD scores is provided, MoRF, LeRF, and combined ROAD scores will be visualized on each image.",
     required=False,
     default=None,
     cls=ResourceOption,
@@ -90,7 +105,8 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 )
 @verbosity_option(logger=logger, cls=ResourceOption, expose_value=False)
 def visualize(
-    dataset,
+    model,
+    datamodule,
     input_folder,
     output_folder,
     visualize_groundtruth,
@@ -101,34 +117,29 @@ def visualize(
 ) -> None:
     """Generates heatmaps for input CXRs based on existing saliency maps."""
 
-    import torch
+    from .utils import save_sh_command
+    from ..engine.visualizer import run
 
-    from torch.utils.data import DataLoader
+    save_sh_command(input_folder / "command.sh")
 
-    from ..engine.visualizer import run
+    datamodule.set_chunk_size(1, 1)
+    datamodule.drop_incomplete_batch = False
+    #datamodule.cache_samples = cache_samples
+    #datamodule.parallel = parallel
+    datamodule.model_transforms = model.model_transforms
 
-    if "datadir" in dataset:
-        img_dir = dataset["datadir"]
-        dataset = (
-            dataset["dataset"]
-            if isinstance(dataset["dataset"], dict)
-            else dict(test=dataset["dataset"])
-        )
-
-        for k, v in dataset.items():
-            if k.startswith("_"):
-                logger.info(f"Skipping dataset '{k}' (not to be evaluated)")
-                continue
-
-            data_loader = DataLoader(
-                dataset=v,
-                batch_size=1,
-                shuffle=False,
-                pin_memory=torch.cuda.is_available(),
-            )
+    datamodule.prepare_data()
+    datamodule.setup(stage="predict")
+
+    dataloaders = datamodule.predict_dataloader()
+
+    for k, v in dataloaders.items():
+        # Hacky way to get img_dir
+        img_dir = datamodule.splits[k][0][1].datadir
 
+        if img_dir:
             run(
-                data_loader,
+                data_loader=v,
                 img_dir=img_dir,
                 input_folder=input_folder,
                 output_folder=output_folder,
@@ -138,7 +149,7 @@ def visualize(
                 visualize_detected_bbox=visualize_detected_bbox,
                 threshold=threshold,
             )
-    else:
-        logger.warning(
-            'No "datadir" or "dataset_name" key found in dataset. No visualizations can be generated.'
-        )
+        else:
+            logger.warning(
+                'No "datadir" found in dataset. No visualizations can be generated.'
+            )
\ No newline at end of file
diff --git a/src/ptbench/utils/cam_utils.py b/src/ptbench/utils/cam_utils.py
index 4472f451..15699342 100644
--- a/src/ptbench/utils/cam_utils.py
+++ b/src/ptbench/utils/cam_utils.py
@@ -1,4 +1,3 @@
-import ast
 import glob
 import os
 import shutil
@@ -98,10 +97,8 @@ def calculate_metrics_avg_for_every_class(input_folder):
 
 def draw_boxes_on_image(img_with_boxes, bboxes):
     for i, bbox in enumerate(bboxes):
-        unpacked_bbox = bbox[0]
-        bbox_dict = ast.literal_eval(unpacked_bbox.replace("'", '"'))
-        xmin, ymin = int(bbox_dict["xmin"]), int(bbox_dict["ymin"])
-        width, height = bbox_dict["width"], bbox_dict["height"]
+        xmin, ymin = int(bbox[1].item()), int(bbox[2].item())
+        width, height = int(bbox[3].item()), int(bbox[4].item())
         xmax = int(xmin + width)
         ymax = int(ymin + height)
         cv2.rectangle(
-- 
GitLab