Skip to content
Snippets Groups Projects

Adds grad-cam support on classifiers

Merged André Anjos requested to merge add-datamodule-gradcam into main
Compare and
27 files
+ 3171
0
Compare changes
  • Side-by-side
  • Inline
Files
27
+ 253
0
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
import logging
import os
import numpy as np
import torch
from pytorch_grad_cam.metrics.road import (
ROADCombined,
ROADLeastRelevantFirstAverage,
ROADMostRelevantFirstAverage,
)
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from tqdm import tqdm
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class SigmoidClassifierOutputTarget:
    """Callable target returning the sigmoid score of one output category.

    Parameters
    ----------
    category
        Index of the model output whose sigmoid-activated value is selected
        when the target is called.
    """

    def __init__(self, category):
        self.category = category

    def __call__(self, model_output):
        """Apply a sigmoid to ``model_output`` and select ``self.category``.

        A 1-D output is indexed directly.  Any other output is indexed along
        its second axis, keeping every entry of the first axis.
        """
        activated = torch.sigmoid(model_output)
        if len(activated.shape) != 1:
            return activated[:, self.category]
        return activated[self.category]
# Lookup table from integer output index to a class label.  The labels are
# used as keys when selecting a per-class CSV writer.
rs_maps = dict(
    enumerate(
        (
            "cardiomegaly",
            "emphysema",
            "effusion",
            "hernia",
            "infiltration",
            "mass",
            "nodule",
            "atelectasis",
            "pneumothorax",
            "pleural thickening",
            "pneumonia",
            "fibrosis",
            "edema",
            "consolidation",
        )
    )
)
def calculate_road_metrics(
    input_image,
    grayscale_cams,
    model,
    targets=None,
    percentiles=None,
):
    """Calculate averaged ROAD scores for one input and its CAMs.

    Three ``pytorch_grad_cam`` ROAD metrics are built with the same
    percentiles and evaluated in turn: most-relevant-first (MoRF),
    least-relevant-first (LeRF) and the combined metric.

    Parameters
    ----------
    input_image
        Forwarded to each metric as ``input_tensor``.
    grayscale_cams
        Forwarded to each metric as ``cams``.
    model
        Forwarded to each metric as ``model``.
    targets
        Forwarded to each metric as ``targets``.
    percentiles
        Percentiles handed to each metric constructor.  When ``None``, the
        default ``[20, 40, 60, 80]`` is used.

    Returns
    -------
    tuple
        ``(MoRF_scores, LeRF_scores, combined_scores)``, exactly as returned
        by the three metric objects.
    """
    if percentiles is None:
        # Build the default on every call instead of sharing one mutable
        # list object between all invocations of this function.
        percentiles = [20, 40, 60, 80]

    morf_metric = ROADMostRelevantFirstAverage(percentiles=percentiles)
    lerf_metric = ROADLeastRelevantFirstAverage(percentiles=percentiles)
    combined_metric = ROADCombined(percentiles=percentiles)

    # All three metrics are called with the same keyword arguments.
    metric_kwargs = {
        "input_tensor": input_image,
        "cams": grayscale_cams,
        "model": model,
        "targets": targets,
    }

    # NOTE(review): ROADCombined presumably re-evaluates both orderings
    # internally, so the model may be run more often than strictly needed --
    # confirm against pytorch_grad_cam before optimising this away.
    MoRF_scores = morf_metric(**metric_kwargs)
    LeRF_scores = lerf_metric(**metric_kwargs)
    combined_scores = combined_metric(**metric_kwargs)

    return MoRF_scores, LeRF_scores, combined_scores
def process_target_class(
    model,
    names,
    images,
    targets,
    metric_targets,
    cam,
    csv_writer,
    percentiles,
):
    """Compute the ROAD scores for one target class and write them as a row.

    The CAMs are produced by calling ``cam`` with ``images`` and ``targets``;
    the ROAD metrics are then evaluated with ``metric_targets``.  Only the
    first entry of ``names`` and of each score collection is used, giving one
    CSV row ``[name, MoRF, LeRF, combined, str(percentiles)]``.
    """
    heatmaps = cam(input_tensor=images, targets=targets)

    road_scores = calculate_road_metrics(
        input_image=images,
        grayscale_cams=heatmaps,
        model=model,
        targets=metric_targets,
        percentiles=percentiles,
    )
    morf_all, lerf_all, combined_all = road_scores

    # Keep the first score of each metric, in MoRF / LeRF / combined order.
    first_scores = [scores[0] for scores in (morf_all, lerf_all, combined_all)]

    # Write metrics to csv file
    csv_writer.writerow([names[0], *first_scores, str(percentiles)])
def run(
    model,
    data_loader,
    output_folder,
    device,
    cam,
    csv_writers,
    target_class="highest",
    tb_positive_only=True,
):
    """Applies visualization techniques on input CXR, and perturbs them to
    calculate ROAD scores.

    For every sample yielded by ``data_loader`` the CAM is computed and its
    ROAD scores (percentiles 20, 40, 60 and 80) are written through the
    matching entry of ``csv_writers``.

    Parameters
    ----------
    model
        Neural network model (e.g. pasa).  A model whose class is named
        ``DensenetRS`` is handled as a multi-label model; any other model is
        evaluated on output index 0.
    data_loader
        Iterable of samples.  Each sample is indexed as ``sample[0]`` (the
        images, moved to ``device``) and ``sample[1]`` (a mapping providing
        ``"label"`` and ``"name"``).  ``.item()`` is called on the label and
        only the first name is written, so batches of a single sample are
        presumably expected -- TODO confirm.
    output_folder : str
        Directory that is created if missing and reported in the log.  The
        scores themselves are emitted through ``csv_writers``.
    device : str
        A string indicating the device to use (e.g. "cpu" or "cuda").  The
        device can also be specified (cuda:0).
    cam
        The CAM object to use for visualization (e.g. one of the
        ``pytorch_grad_cam`` CAM classes).  Its ``activations_and_grads``
        attribute is called to obtain the model outputs for ``DensenetRS``.
    csv_writers : dict
        Dictionary of csv writer objects.  The key ``"targeted_class"`` is
        always used; when all classes are processed, the class labels from
        ``rs_maps`` are used as keys as well.
    target_class : str
        (Use only with multi-label models) Which class to target for CAM
        calculation, compared case-insensitively.  "highest" is the default
        and scores only the class with the highest activation.  "all" scores
        every class in ``rs_maps`` and then, additionally, the class with the
        highest activation.
    tb_positive_only : bool
        If set, samples whose label equals 0 (TB negative) are skipped.

    Returns
    -------
    None
        Nothing is returned; results are only written via ``csv_writers``.
    """
    output_folder = os.path.abspath(output_folder)
    logger.info(f"Output folder: {output_folder}")
    os.makedirs(output_folder, exist_ok=True)

    # Loop-invariant decisions, evaluated once instead of for every sample.
    is_multilabel = model.__class__.__name__ == "DensenetRS"
    score_every_class = is_multilabel and target_class.lower() == "all"
    non_blocking = torch.cuda.is_available()

    percentiles = [20, 40, 60, 80]

    for samples in tqdm(data_loader, desc="batches", leave=False, disable=None):
        # TB negative labels are skipped
        if samples[1]["label"].item() == 0 and tb_positive_only:
            continue

        names = samples[1]["name"]
        images = samples[0].to(device=device, non_blocking=non_blocking)

        if score_every_class:
            # Iterate the label table itself rather than a hard-coded class
            # count, so the indices and the writer keys cannot drift apart.
            for target, class_name in rs_maps.items():
                targets = [ClassifierOutputTarget(target)]
                metric_targets = [SigmoidClassifierOutputTarget(target)]
                csv_writer = csv_writers[class_name]
                process_target_class(
                    model,
                    names,
                    images,
                    targets,
                    metric_targets,
                    cam,
                    csv_writer,
                    percentiles,
                )

        # The "targeted_class" row below is written in every mode, including
        # after the per-class rows above.
        if is_multilabel:
            # Get the class with highest activation manually
            outputs = cam.activations_and_grads(images)
            target_categories = np.argmax(outputs.cpu().data.numpy(), axis=-1)
            targets = [
                ClassifierOutputTarget(category)
                for category in target_categories
            ]
            metric_targets = [
                SigmoidClassifierOutputTarget(category)
                for category in target_categories
            ]
        else:
            targets = [ClassifierOutputTarget(0)]
            metric_targets = [SigmoidClassifierOutputTarget(0)]

        csv_writer = csv_writers["targeted_class"]
        process_target_class(
            model,
            names,
            images,
            targets,
            metric_targets,
            cam,
            csv_writer,
            percentiles,
        )
Loading