Commit 2a7150d6 authored by André Anjos, committed by Daniel CARRON

[engine,scripts] Simplify saliencymap generator; Add more type hinting; Make it TB-agnostic; Closes #49
parent 6bd5719e
1 merge request: !12 Adds grad-cam support on classifiers
@@ -3,149 +3,214 @@
# SPDX-License-Identifier: GPL-3.0-or-later
import logging
import os
import pathlib
import typing
import numpy as np
import lightning.pytorch
import numpy
import torch
import torch.nn
import tqdm
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from tqdm import tqdm
from ..models.typing import VisualisationType
from .device import DeviceManager
logger = logging.getLogger(__name__)
rs_maps = {
0: "cardiomegaly",
1: "emphysema",
2: "effusion",
3: "hernia",
4: "infiltration",
5: "mass",
6: "nodule",
7: "atelectasis",
8: "pneumothorax",
9: "pleural thickening",
10: "pneumonia",
11: "fibrosis",
12: "edema",
13: "consolidation",
}
def _save_npy(img_name_stem, grayscale_cam, visualization_path):
# saves a grayscale CAM as a .npy file (np.save appends the extension)
image_path = f"{visualization_path}/{img_name_stem}"
os.makedirs(os.path.dirname(image_path), exist_ok=True)
np.save(image_path, grayscale_cam)
# Helper function to calculate saliency maps for a single target class
# of a single input image.
def process_target_class(names, images, targets, visualization_path, cam):
grayscale_cams = cam(input_tensor=images, targets=targets)
for i, grayscale_cam in enumerate(grayscale_cams):
img_name_stem = names[i].split(".")[0]
_save_npy(img_name_stem, grayscale_cam, visualization_path)
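# Illustrative call of the helper above (all names are placeholders, not part
# of this module): given a batch of image tensors and their filenames, this
# would write one `<stem>.npy` file per image under `out/gradcam`:
#
#     targets = [ClassifierOutputTarget(0)]
#     process_target_class(names, images, targets, "out/gradcam", cam)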
def _create_visualiser(
vis_type: VisualisationType,
model: torch.nn.Module,
target_layers: list[torch.nn.Module] | None,
use_cuda: bool,
):
"""Creates a class activation map (CAM) instance for a given model."""
import pytorch_grad_cam
match vis_type:
case "gradcam":
return pytorch_grad_cam.GradCAM(
model=model, target_layers=target_layers, use_cuda=use_cuda
)
case "scorecam":
return pytorch_grad_cam.ScoreCAM(
model=model, target_layers=target_layers, use_cuda=use_cuda
)
case "fullgrad":
return pytorch_grad_cam.FullGrad(
model=model, target_layers=target_layers, use_cuda=use_cuda
)
case "randomcam":
return pytorch_grad_cam.RandomCAM(
model=model, target_layers=target_layers, use_cuda=use_cuda
)
case "hirescam":
return pytorch_grad_cam.HiResCAM(
model=model, target_layers=target_layers, use_cuda=use_cuda
)
case "gradcamelementwise":
return pytorch_grad_cam.GradCAMElementWise(
model=model, target_layers=target_layers, use_cuda=use_cuda
)
case "gradcam++", "gradcamplusplus":
return pytorch_grad_cam.GradCAMPlusPlus(
model=model, target_layers=target_layers, use_cuda=use_cuda
)
case "xgradcam":
return pytorch_grad_cam.XGradCAM(
model=model, target_layers=target_layers, use_cuda=use_cuda
)
case "ablationcam":
assert (
target_layers is not None
), "AblationCAM cannot have target_layers=None"
return pytorch_grad_cam.AblationCAM(
model=model, target_layers=target_layers, use_cuda=use_cuda
)
case "eigencam":
return pytorch_grad_cam.EigenCAM(
model=model, target_layers=target_layers, use_cuda=use_cuda
)
case "eigengradcam":
return pytorch_grad_cam.EigenGradCAM(
model=model, target_layers=target_layers, use_cuda=use_cuda
)
case "layercam":
return pytorch_grad_cam.LayerCAM(
model=model, target_layers=target_layers, use_cuda=use_cuda
)
case _:
raise ValueError(
f"Visualisation type `{vis_type}` is not currently supported."
)
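# A minimal usage sketch for the factory above (illustrative only; the toy
# model and input are hypothetical, and this assumes the pinned
# pytorch-grad-cam version, which still accepts `use_cuda`):
#
#     toy_model = torch.nn.Sequential(
#         torch.nn.Conv2d(1, 4, 3, padding=1),
#         torch.nn.AdaptiveAvgPool2d(1),
#         torch.nn.Flatten(),
#         torch.nn.Linear(4, 2),
#     )
#     cam = _create_visualiser(
#         "gradcam", toy_model, target_layers=[toy_model[0]], use_cuda=False
#     )
#     saliency = cam(input_tensor=torch.rand(1, 1, 32, 32), targets=None)
#     # -> numpy array of shape (1, 32, 32), one map per input image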
def _save_visualisation(
output_folder: pathlib.Path, name: str, vis: torch.Tensor
) -> None:
"""Helper function to save a saliency map to disk."""
n = pathlib.Path(name)
(output_folder / n.parent).mkdir(parents=True, exist_ok=True)
numpy.save(output_folder / n.with_suffix(".npy"), vis)
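# For example (hypothetical names), a map computed for a sample named
# "train/img_0001.png" would be stored as `out/train/img_0001.npy`:
#
#     _save_visualisation(pathlib.Path("out"), "train/img_0001.png", vis)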
def run(
model,
data_loader,
output_folder,
dataset_split_name,
device,
cam,
visualization_type,
target_class="highest",
tb_positive_only=True,
):
"""Applies visualization techniques on input CXR, outputs pickled saliency
maps.
model: lightning.pytorch.LightningModule,
datamodule: lightning.pytorch.LightningDataModule,
device_manager: DeviceManager,
visualisation_types: typing.Sequence[VisualisationType],
target_class: typing.Literal["highest", "all"],
positive_only: bool,
output_folder: pathlib.Path,
) -> None:
"""Applies visualisation techniques on input CXR, outputs pickled saliency
maps directly to disk.
Parameters
----------
model
Neural network model (e.g. pasa).
data_loader
The pytorch lightning Dataloader used to iterate over batches.
output_folder : str
Directory in which the results will be saved.
dataset_split_name : str
Name of the dataset split (e.g. "train", "validation", "test").
device : str
A string indicating the device to use (e.g. "cpu" or "cuda"). A specific device ordinal may also be given (e.g. "cuda:0").
cam : object
The CAM object to use for visualization: one of the classes provided by
:py:mod:`pytorch_grad_cam` (e.g. ``GradCAM``, ``ScoreCAM``, ``FullGrad``,
``RandomCAM``, ``EigenCAM``, ``EigenGradCAM``, ``LayerCAM``, ``XGradCAM``,
``AblationCAM``, ``HiResCAM``, ``GradCAMElementWise`` or
``GradCAMPlusPlus``).
visualization_types : list
Types of visualization techniques to be applied. Possible values are:
"GradCAM", "ScoreCAM", "FullGrad", "RandomCAM", "HiResCAM", "GradCAMElementWise", "GradCAMPlusPlus", "XGradCAM", "AblationCAM",
"EigenCAM", "EigenGradCAM", "LayerCAM".
target_class : str
(Use only with multi-label models) Which class to target for CAM calculation. Can be set to either "all" or "highest". With "highest" (the default), only saliency maps for the class with the highest activation will be generated.
tb_positive_only : bool
If set, saliency maps will only be generated for TB positive samples.
Returns
-------
all_predictions : list
All the predictions associated with filename and ground truth, saved as .csv.
datamodule
The lightning datamodule from which prediction dataloaders will be drawn
device_manager
An internal device representation, to be used for inference. This
representation can be converted into a pytorch device
or a torch lightning accelerator setup.
visualisation_types
The types of visualisations to generate for the current model and
datamodule.
target_class
(Use only with multi-label models) Which class to target for CAM
calculation. Can be either set to "all" or "highest". "highest" is
default, which means only saliency maps for the class with the highest
activation will be generated.
positive_only
If set, saliency maps will only be generated for positive samples (i.e.
label == 1 in a binary classification task). This option is ignored on
a multi-class output model.
output_folder
Where to save all the visualisations (this path should exist before
this function is called).
"""
output_folder = os.path.abspath(output_folder)
logger.info(f"Output folder: {output_folder}")
os.makedirs(output_folder, exist_ok=True)
model_name = model.__class__.__name__
for samples in tqdm(data_loader, desc="batches", leave=False, disable=None):
# TB negative labels are skipped (they don't have bboxes)
if samples[1]["label"].item() == 0:
if tb_positive_only:
continue
names = samples[1]["name"]
images = samples[0].to(
device=device, non_blocking=torch.cuda.is_available()
)
if model_name == "DensenetRS" and target_class.lower() == "all":
for target in range(14):
targets = [ClassifierOutputTarget(target)]
visualization_path = f"{output_folder}/{visualization_type}/{rs_maps[target]}/{dataset_split_name}"
os.makedirs(visualization_path, exist_ok=True)
process_target_class(
names, images, targets, visualization_path, cam
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from ..models.densenet import Densenet
from ..models.pasa import Pasa
if isinstance(model, Pasa):
if "fullgrad" in visualisation_types:
raise ValueError(
"Fullgrad visualisation is not supported for the Pasa model."
)
target_layers = [model.fc14] # Last non-1x1 Conv2d layer
elif isinstance(model, Densenet):
target_layers = [
model.model_ft.features.denseblock4.denselayer16.conv2, # type: ignore
]
else:
raise TypeError(f"Model of type `{type(model)}` is not yet supported.")
use_cuda = device_manager.device_type == "cuda"
# prepares model for evaluation, cast to target device
device = device_manager.torch_device()
model = model.to(device)
model.eval()
for vis_type in visualisation_types:
cam = _create_visualiser(vis_type, model, target_layers, use_cuda) # type: ignore
for k, v in datamodule.predict_dataloader().items():
logger.info(
f"Generating saliency maps for dataset `{k}` via `{vis_type}`..."
)
for sample in tqdm.tqdm(
v, desc="samples", leave=False, disable=None
):
name = sample[1]["name"][0]
label = sample[1]["label"].item()
image = sample[0].to(
device=device, non_blocking=torch.cuda.is_available()
)
if model_name == "DensenetRS":
# Get the class with highest activation manually
outputs = cam.activations_and_grads(images)
target_categories = np.argmax(outputs.cpu().data.numpy(), axis=-1)
targets = [
ClassifierOutputTarget(category)
for category in target_categories
]
else:
targets = [ClassifierOutputTarget(0)]
visualization_path = f"{output_folder}/{visualization_type}/targeted_class/{dataset_split_name}"
os.makedirs(visualization_path, exist_ok=True)
process_target_class(names, images, targets, visualization_path, cam)
# in binary classification systems, negative labels may be skipped
if positive_only and (model.num_classes == 1) and (label == 0):
continue
# chooses target outputs to generate visualisations for
if model.num_classes > 1:
if target_class == "all":
# just blindly generate visualisations for all outputs
# - make one directory for every target output and lay
# images there like in the original dataset.
for output_num in range(model.num_classes):
use_folder = (
output_folder / vis_type / str(output_num)
)
vis = cam(
input_tensor=image,
targets=[ClassifierOutputTarget(output_num)], # type: ignore
)
_save_visualisation(use_folder, name, vis) # type: ignore
else:
# pytorch-grad-cam will evaluate the output with the
# highest value and produce a visualisation for it - we will
# save it to disk.
use_folder = output_folder / vis_type / "highest-output"
vis = cam(input_tensor=image, targets=None) # type: ignore
_save_visualisation(use_folder, name, vis) # type: ignore
else:
# binary classification model with a single output - just
# lay all cams uniformly like the original dataset
use_folder = output_folder / vis_type
vis = cam(
input_tensor=image,
targets=[
ClassifierOutputTarget(0), # type: ignore
],
)
_save_visualisation(use_folder, name, vis) # type: ignore
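# A hedged end-to-end sketch of this engine (assumes `model` and `datamodule`
# are configured resources, as set up by the CLI below; the import path
# follows this package's layout):
#
#     from ptbench.engine.device import DeviceManager
#     run(
#         model=model,
#         datamodule=datamodule,
#         device_manager=DeviceManager("cpu"),
#         visualisation_types=["gradcam"],
#         target_class="highest",
#         positive_only=False,
#         output_folder=pathlib.Path("visualisations"),
#     )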
@@ -3,118 +3,16 @@
# SPDX-License-Identifier: GPL-3.0-or-later
import pathlib
import typing
import click
from clapper.click import ConfigCommand, ResourceOption, verbosity_option
from clapper.logging import setup
from pytorch_grad_cam import (
AblationCAM,
EigenCAM,
EigenGradCAM,
FullGrad,
GradCAM,
GradCAMElementWise,
GradCAMPlusPlus,
HiResCAM,
LayerCAM,
RandomCAM,
ScoreCAM,
XGradCAM,
)
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
allowed_visualization_types = {
"gradcam",
"scorecam",
"fullgrad",
"randomcam",
"hirescam",
"gradcamelementwise",
"gradcam++",
"gradcamplusplus",
"xgradcam",
"ablationcam",
"eigencam",
"eigengradcam",
"layercam",
}
# Ensures the user has selected a supported visualization type
def check_vis_types(vis_types):
if isinstance(vis_types, str):
vis_types = [vis_types.lower()]
else:
vis_types = [s.lower() for s in vis_types]
from ..models.typing import VisualisationType
for s in vis_types:
if not isinstance(s, str):
raise click.BadParameter(
"Visualization type must be a string or a list of strings"
)
if s not in allowed_visualization_types:
raise click.BadParameter(
"Visualization type must be one of: {}".format(
", ".join(allowed_visualization_types)
)
)
return vis_types
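# Illustrative behaviour of the validator above:
#
#     check_vis_types("GradCAM")                 # -> ["gradcam"]
#     check_vis_types(["HiResCAM", "LayerCAM"])  # -> ["hirescam", "layercam"]
#     check_vis_types("unknown")                 # raises click.BadParameter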
# CAM factory
def create_cam(vis_type, model, target_layers, use_cuda):
if vis_type == "gradcam":
return GradCAM(
model=model, target_layers=target_layers, use_cuda=use_cuda
)
elif vis_type == "scorecam":
return ScoreCAM(
model=model, target_layers=target_layers, use_cuda=use_cuda
)
elif vis_type == "fullgrad":
return FullGrad(
model=model, target_layers=target_layers, use_cuda=use_cuda
)
elif vis_type == "randomcam":
return RandomCAM(
model=model, target_layers=target_layers, use_cuda=use_cuda
)
elif vis_type == "hirescam":
return HiResCAM(
model=model, target_layers=target_layers, use_cuda=use_cuda
)
elif vis_type == "gradcamelementwise":
return GradCAMElementWise(
model=model, target_layers=target_layers, use_cuda=use_cuda
)
elif vis_type == "gradcam++" or vis_type == "gradcamplusplus":
return GradCAMPlusPlus(
model=model, target_layers=target_layers, use_cuda=use_cuda
)
elif vis_type == "xgradcam":
return XGradCAM(
model=model, target_layers=target_layers, use_cuda=use_cuda
)
elif vis_type == "ablationcam":
return AblationCAM(
model=model, target_layers=target_layers, use_cuda=use_cuda
)
elif vis_type == "eigencam":
return EigenCAM(
model=model, target_layers=target_layers, use_cuda=use_cuda
)
elif vis_type == "eigengradcam":
return EigenGradCAM(
model=model, target_layers=target_layers, use_cuda=use_cuda
)
elif vis_type == "layercam":
return LayerCAM(
model=model, target_layers=target_layers, use_cuda=use_cuda
)
else:
raise ValueError(f"Unsupported visualization type: {vis_type}")
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
@click.command(
@@ -123,50 +21,49 @@ def create_cam(vis_type, model, target_layers, use_cuda):
epilog="""Examples:
\b
1. Generates saliency maps and saves them as pickled objects:
1. Generates saliency maps for all prediction dataloaders on a datamodule,
using a pre-trained DenseNet model, and saves them as numpy-pickled
objects on the output directory:
.. code:: sh
ptbench generate-saliencymaps -vv densenet tbx11k-v1-healthy-vs-atb --device="cuda" --weight=path/to/model_final.pth --output-folder=path/to/visualizations
\b
ptbench generate-saliencymaps -vv densenet tbx11k-v1-healthy-vs-atb --weight=path/to/model_final.pth --output-folder=path/to/visualisations
""",
)
@click.option(
"--model",
"-m",
help="A lightining module instance implementing the network to be used for inference.",
help="""A lightining module instance implementing the network architecture
(not the weights, necessarily) to be used for inference. Currently, only
supports pasa and densenet models.""",
required=True,
cls=ResourceOption,
)
@click.option(
"--datamodule",
"-d",
help="A lighting data module containing the training, validation and test sets.",
help="""A lighting data module that will be asked for prediction data
loaders. Typically, this includes all configured splits in a datamodule,
however this is not a requirement. A datamodule that returns a single
dataloader for prediction (wrapped in a dictionary) is acceptable.""",
required=True,
cls=ResourceOption,
)
@click.option(
"--output-folder",
"-o",
help="Path where to store the visualizations (created if does not exist)",
help="Path where to store the visualisations (created if does not exist)",
required=True,
type=click.Path(
exists=False,
file_okay=False,
dir_okay=True,
writable=True,
path_type=pathlib.Path,
),
default="visualizations",
cls=ResourceOption,
)
@click.option(
"--batch-size",
"-b",
help="Number of samples in every batch (this parameter affects memory requirements for the network)",
required=True,
show_default=True,
default=1,
type=click.IntRange(min=1),
default="visualisations",
cls=ResourceOption,
)
@click.option(
@@ -178,6 +75,15 @@ def create_cam(vis_type, model, target_layers, use_cuda):
default="cpu",
cls=ResourceOption,
)
@click.option(
"--cache-samples/--no-cache-samples",
help="If set to True, loads the sample into memory, "
"otherwise loads them at runtime.",
required=True,
show_default=True,
default=False,
cls=ResourceOption,
)
@click.option(
"--weight",
"-w",
@@ -185,31 +91,53 @@ def create_cam(vis_type, model, target_layers, use_cuda):
corresponding to the architecture set with `--model`.""",
required=True,
cls=ResourceOption,
type=click.Path(exists=True, file_okay=True, dir_okay=False, readable=True),
)
@click.option(
"--visualization-types",
"--parallel",
"-P",
help="""Use multiprocessing for data loading: if set to -1 (default),
disables multiprocessing data loading. Set to 0 to enable as many data
loading instances as processing cores as available in the system. Set to
>= 1 to enable that many multiprocessing instances for data loading.""",
type=click.IntRange(min=-1),
show_default=True,
required=True,
default=-1,
cls=ResourceOption,
)
@click.option(
"--visualisation-types",
"-vt",
help="Visualization techniques to be used. Can be called multiple times with different techniques. Currently supported ones are: "
'"GradCAM", "ScoreCAM", "FullGrad", "RandomCAM", "HiResCAM", "GradCAMElementWise", "GradCAMPlusPlus", "XGradCAM", "AblationCAM", '
'"EigenCAM", "EigenGradCAM", "LayerCAM"',
help="""visualisation techniques to be used. Can be called multiple times
with different techniques.""",
type=click.Choice(typing.get_args(VisualisationType), case_sensitive=False),
multiple=True,
default=["GradCAM"],
default=["gradcam"],
show_default=True,
cls=ResourceOption,
)
@click.option(
"--target-class",
"-tc",
help='(Use only with multi-label models) Which class to target for CAM calculation. Can be either set to "all" or "highest". "highest" is default, which means only saliency maps for the class with the highest activation will be generated.',
type=str,
help="""This option should only be used with multiclass models. It
defines the class to target for saliency estimation. Can be either set to
"all" or "highest". "highest" (the default), means only saliency maps for
the class with the highest activation will be generated.""",
required=False,
type=click.Choice(
["highest", "all"],
case_sensitive=False,
),
default="highest",
cls=ResourceOption,
)
@click.option(
"--tb-positive-only",
"-tb",
help="If set, saliency maps will only be generated for TB positive samples.",
is_flag=True,
"--positive-only/--no-positive-only",
"-p/-P",
help="""If set, and the model chosen has a single output (binary), then
saliency maps will only be generated for samples of the positive class.
This option has no effect for multiclass models.""",
default=False,
cls=ResourceOption,
)
@@ -218,69 +146,52 @@ def generate_saliencymaps(
model,
datamodule,
output_folder,
batch_size,
device,
cache_samples,
weight,
visualization_types,
parallel,
visualisation_types,
target_class,
tb_positive_only,
positive_only,
**_,
) -> None:
"""Generates saliency maps for locations with aTB for input CXRs, depending
on visualization technique and model."""
"""Generates saliency maps for locations on input images that affected the
prediction.
The quality of the saliency information depends on the visualisation
technique and the trained model.
"""
from ..engine.device import DeviceManager
from ..engine.saliencymap_generator import run
from .utils import save_sh_command
logger.info(f"Output folder: {output_folder}")
output_folder.mkdir(parents=True, exist_ok=True)
save_sh_command(output_folder / "command.sh")
device_manager = DeviceManager(device)
device = device_manager.torch_device()
use_cuda = device_manager.device_type == "cuda"
datamodule.set_chunk_size(batch_size, 1)
datamodule.drop_incomplete_batch = False
# datamodule.cache_samples = cache_samples
# datamodule.parallel = parallel
# batch_size must be == 1 for now (underlying code is NOT prepared to
# treat multiple samples at once).
datamodule.set_chunk_size(1, 1)
datamodule.cache_samples = cache_samples
datamodule.parallel = parallel
datamodule.model_transforms = model.model_transforms
datamodule.prepare_data()
datamodule.setup(stage="predict")
dataloaders = datamodule.predict_dataloader()
logger.info(f"Loading checkpoint from {weight}")
logger.info(f"Loading checkpoint from `{weight}`...")
model = model.load_from_checkpoint(weight, strict=False)
visualization_types = check_vis_types(visualization_types)
model_name = model.__class__.__name__
if model_name == "Pasa":
if "fullgrad" in visualization_types:
raise ValueError(
"Fullgrad visualization is not supported for the Pasa model."
)
target_layers = [model.fc14] # Last non-1x1 Conv2d layer
else:
target_layers = [model.model_ft.features.denseblock4.denselayer16.conv2]
for vis_type in visualization_types:
cam = create_cam(vis_type, model, target_layers, use_cuda)
for k, v in dataloaders.items():
logger.info(f"Generating saliency maps for '{k}' set...")
run(
model,
data_loader=v,
output_folder=output_folder,
dataset_split_name=k,
device=device,
cam=cam,
visualization_type=vis_type,
target_class=target_class,
tb_positive_only=tb_positive_only,
)
run(
model=model,
datamodule=datamodule,
device_manager=device_manager,
visualisation_types=visualisation_types,
target_class=target_class,
positive_only=positive_only,
output_folder=output_folder,
)
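# A hedged command-line sketch exercising the new options (resource names and
# paths mirror the epilog above and are placeholders):
#
#     ptbench generate-saliencymaps -vv densenet tbx11k-v1-healthy-vs-atb \
#         --weight=path/to/model_final.pth --output-folder=visualisations \
#         --visualisation-types=gradcam --target-class=highest \
#         --positive-only --parallel=0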