Commit 3920e964 authored by André Anjos

[engine,scripts] Simplify saliencymap generator; Add more type hinting; Make it TB-agnostic; Closes #49
parent 8844778e
Pipeline #78624 failed
engine/saliencymap_generator.py (new side of the diff shown):

@@ -3,149 +3,214 @@
# SPDX-License-Identifier: GPL-3.0-or-later

import logging
import pathlib
import typing

import lightning.pytorch
import numpy
import torch
import torch.nn
import tqdm

from ..models.typing import VisualisationType
from .device import DeviceManager

logger = logging.getLogger(__name__)
(removed: the hard-coded 14-class rs_maps label dictionary, together with the
_save_npy and process_target_class helpers it supported)
def _create_visualiser(
    vis_type: VisualisationType,
    model: torch.nn.Module,
    target_layers: list[torch.nn.Module] | None,
    use_cuda: bool,
):
    """Creates a class activation map (CAM) instance for a given model."""

    import pytorch_grad_cam

    match vis_type:
        case "gradcam":
            return pytorch_grad_cam.GradCAM(
                model=model, target_layers=target_layers, use_cuda=use_cuda
            )
        case "scorecam":
            return pytorch_grad_cam.ScoreCAM(
                model=model, target_layers=target_layers, use_cuda=use_cuda
            )
        case "fullgrad":
            return pytorch_grad_cam.FullGrad(
                model=model, target_layers=target_layers, use_cuda=use_cuda
            )
        case "randomcam":
            return pytorch_grad_cam.RandomCAM(
                model=model, target_layers=target_layers, use_cuda=use_cuda
            )
        case "hirescam":
            return pytorch_grad_cam.HiResCAM(
                model=model, target_layers=target_layers, use_cuda=use_cuda
            )
        case "gradcamelementwise":
            return pytorch_grad_cam.GradCAMElementWise(
                model=model, target_layers=target_layers, use_cuda=use_cuda
            )
        case "gradcam++" | "gradcamplusplus":
            return pytorch_grad_cam.GradCAMPlusPlus(
                model=model, target_layers=target_layers, use_cuda=use_cuda
            )
        case "xgradcam":
            return pytorch_grad_cam.XGradCAM(
                model=model, target_layers=target_layers, use_cuda=use_cuda
            )
        case "ablationcam":
            assert (
                target_layers is not None
            ), "AblationCAM cannot have target_layers=None"
            return pytorch_grad_cam.AblationCAM(
                model=model, target_layers=target_layers, use_cuda=use_cuda
            )
        case "eigencam":
            return pytorch_grad_cam.EigenCAM(
                model=model, target_layers=target_layers, use_cuda=use_cuda
            )
        case "eigengradcam":
            return pytorch_grad_cam.EigenGradCAM(
                model=model, target_layers=target_layers, use_cuda=use_cuda
            )
        case "layercam":
            return pytorch_grad_cam.LayerCAM(
                model=model, target_layers=target_layers, use_cuda=use_cuda
            )
        case _:
            raise ValueError(
                f"Visualisation type `{vis_type}` is not currently supported."
            )
def _save_visualisation(
    output_folder: pathlib.Path, name: str, vis: torch.Tensor
) -> None:
    """Helper function to save a saliency map to disk."""

    n = pathlib.Path(name)
    (output_folder / n.parent).mkdir(parents=True, exist_ok=True)
    numpy.save(output_folder / n.with_suffix(".npy"), vis)
def run(
    model: lightning.pytorch.LightningModule,
    datamodule: lightning.pytorch.LightningDataModule,
    device_manager: DeviceManager,
    visualisation_types: typing.Sequence[VisualisationType],
    target_class: typing.Literal["highest", "all"],
    positive_only: bool,
    output_folder: pathlib.Path,
) -> None:
    """Applies visualisation techniques on input CXR images and outputs
    saliency maps directly to disk (as numpy-pickled files).

    Parameters
    ----------
    model
        Neural network model (e.g. pasa).
    datamodule
        The lightning datamodule from which prediction dataloaders will be
        requested.
    device_manager
        An internal device representation used to run the model. This
        representation can be converted into a pytorch device or a torch
        lightning accelerator setup.
    visualisation_types
        The types of visualisations to generate for the current model and
        datamodule.
    target_class
        (Use only with multi-label models) Which class to target for CAM
        calculation. Can be either set to "all" or "highest". "highest" is
        the default, which means only saliency maps for the class with the
        highest activation will be generated.
    positive_only
        If set, saliency maps will only be generated for positive samples
        (i.e. label == 1 in a binary classification task). This option is
        ignored on a multi-class output model.
    output_folder
        Where to save all the visualisations (this path should exist before
        this function is called).
    """

    from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

    from ..models.densenet import Densenet
    from ..models.pasa import Pasa

    if isinstance(model, Pasa):
        if "fullgrad" in visualisation_types:
            raise ValueError(
                "Fullgrad visualisation is not supported for the Pasa model."
            )
        target_layers = [model.fc14]  # last non-1x1 Conv2d layer
    elif isinstance(model, Densenet):
        target_layers = [
            model.model_ft.features.denseblock4.denselayer16.conv2,  # type: ignore
        ]
    else:
        raise TypeError(f"Model of type `{type(model)}` is not yet supported.")

    use_cuda = device_manager.device_type == "cuda"

    # prepares model for evaluation, casts it to the target device
    device = device_manager.torch_device()
    model = model.to(device)
    model.eval()

    for vis_type in visualisation_types:
        cam = _create_visualiser(vis_type, model, target_layers, use_cuda)  # type: ignore

        for k, v in datamodule.predict_dataloader().items():
            logger.info(
                f"Generating saliency maps for dataset `{k}` via `{vis_type}`..."
            )

            for sample in tqdm.tqdm(
                v, desc="samples", leave=False, disable=None
            ):
                name = sample[1]["name"][0]
                label = sample[1]["label"].item()
                image = sample[0].to(
                    device=device, non_blocking=torch.cuda.is_available()
                )

                # in binary classification systems, negative labels may be
                # skipped
                if positive_only and (model.num_classes == 1) and (label == 0):
                    continue

                # chooses target outputs to generate visualisations for
                if model.num_classes > 1:
                    if target_class == "all":
                        # blindly generate visualisations for all outputs:
                        # make one directory for every target output and lay
                        # images there like in the original dataset.
                        for output_num in range(model.num_classes):
                            use_folder = (
                                output_folder / vis_type / str(output_num)
                            )
                            vis = cam(
                                input_tensor=image,
                                targets=[ClassifierOutputTarget(output_num)],  # type: ignore
                            )
                            _save_visualisation(use_folder, name, vis)  # type: ignore
                    else:
                        # pytorch-grad-cam will evaluate the output with the
                        # highest value and produce a visualisation for it -
                        # we save it to disk.
                        use_folder = (
                            output_folder / vis_type / "highest-output"
                        )
                        vis = cam(input_tensor=image, targets=None)  # type: ignore
                        _save_visualisation(use_folder, name, vis)  # type: ignore
                else:
                    # binary classification model with a single output: just
                    # lay all cams uniformly like the original dataset
                    use_folder = output_folder / vis_type
                    vis = cam(
                        input_tensor=image,
                        targets=[
                            ClassifierOutputTarget(0),  # type: ignore
                        ],
                    )
                    _save_visualisation(use_folder, name, vis)  # type: ignore
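
Note: VisualisationType, imported from ..models.typing, is not part of this
diff. Judging from the match arms in _create_visualiser and the
click.Choice(typing.get_args(VisualisationType)) usage in the CLI below, it is
presumably a typing.Literal alias along these lines (a sketch, not the
committed definition):

import typing

# Presumed shape of ..models.typing.VisualisationType: one string literal per
# case arm accepted by _create_visualiser above.
VisualisationType = typing.Literal[
    "ablationcam",
    "eigencam",
    "eigengradcam",
    "fullgrad",
    "gradcam",
    "gradcam++",
    "gradcamelementwise",
    "gradcamplusplus",
    "hirescam",
    "layercam",
    "randomcam",
    "scorecam",
    "xgradcam",
]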
the generate-saliencymaps CLI script (new side of the diff shown):

@@ -3,118 +3,16 @@
# SPDX-License-Identifier: GPL-3.0-or-later

import pathlib
import typing

import click

from clapper.click import ConfigCommand, ResourceOption, verbosity_option
from clapper.logging import setup

from ..models.typing import VisualisationType

logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")

(removed: the module-level pytorch_grad_cam imports, the
allowed_visualization_types set, the check_vis_types validator and the
create_cam factory; CAM construction now lives in
engine.saliencymap_generator._create_visualiser, and input validation is
handled by click.Choice below)
@click.command(
@@ -123,50 +21,49 @@ def create_cam(vis_type, model, target_layers, use_cuda):
    epilog="""Examples:

\b
    1. Generates saliency maps for all prediction dataloaders on a datamodule,
       using a pre-trained DenseNet model, and saves them as numpy-pickled
       objects on the output directory:

       .. code:: sh

          ptbench generate-saliencymaps -vv densenet tbx11k-v1-healthy-vs-atb --weight=path/to/model_final.pth --output-folder=path/to/visualisations

""",
)
@click.option(
    "--model",
    "-m",
    help="""A lightning module instance implementing the network architecture
    (not the weights, necessarily) to be used for inference. Currently, only
    supports pasa and densenet models.""",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--datamodule",
    "-d",
    help="""A lightning data module that will be asked for prediction data
    loaders. Typically, this includes all configured splits in a datamodule,
    however this is not a requirement. A datamodule that returns a single
    dataloader for prediction (wrapped in a dictionary) is acceptable.""",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--output-folder",
    "-o",
    help="Path where to store the visualisations (created if does not exist)",
    required=True,
    type=click.Path(
        file_okay=False,
        dir_okay=True,
        writable=True,
        path_type=pathlib.Path,
    ),
    default="visualisations",
    cls=ResourceOption,
)

(removed: the --batch-size option; generation now always runs one sample at a
time)
@click.option(
@@ -178,6 +75,15 @@ def create_cam(vis_type, model, target_layers, use_cuda):
    default="cpu",
    cls=ResourceOption,
)
@click.option(
    "--cache-samples/--no-cache-samples",
    help="If set to True, loads the samples into memory, "
    "otherwise loads them at runtime.",
    required=True,
    show_default=True,
    default=False,
    cls=ResourceOption,
)
@click.option(
    "--weight",
    "-w",
@@ -185,31 +91,53 @@ def create_cam(vis_type, model, target_layers, use_cuda):
    corresponding to the architecture set with `--model`.""",
    required=True,
    cls=ResourceOption,
    type=click.Path(exists=True, file_okay=True, dir_okay=False, readable=True),
)
@click.option(
    "--parallel",
    "-P",
    help="""Use multiprocessing for data loading: if set to -1 (default),
    disables multiprocessing data loading. Set to 0 to enable as many data
    loading instances as there are processing cores available in the system.
    Set to >= 1 to enable that many multiprocessing instances for data
    loading.""",
    type=click.IntRange(min=-1),
    show_default=True,
    required=True,
    default=-1,
    cls=ResourceOption,
)
@click.option(
    "--visualisation-types",
    "-vt",
    help="""Visualisation techniques to be used. Can be called multiple times
    with different techniques.""",
    type=click.Choice(typing.get_args(VisualisationType), case_sensitive=False),
    multiple=True,
    default=["gradcam"],
    show_default=True,
    cls=ResourceOption,
)
@click.option(
    "--target-class",
    "-tc",
    help="""This option should only be used with multiclass models. It
    defines the class to target for saliency estimation. Can be either set to
    "all" or "highest". "highest" (the default) means only saliency maps for
    the class with the highest activation will be generated.""",
    required=False,
    type=click.Choice(
        ["highest", "all"],
        case_sensitive=False,
    ),
    default="highest",
    cls=ResourceOption,
)
@click.option(
    "--positive-only/--no-positive-only",
    "-p/-P",
    help="""If set, and the chosen model has a single output (binary), then
    saliency maps will only be generated for samples of the positive class.
    This option has no effect for multiclass models.""",
    default=False,
    cls=ResourceOption,
)
@@ -218,69 +146,52 @@ def generate_saliencymaps(
    model,
    datamodule,
    output_folder,
    device,
    cache_samples,
    weight,
    parallel,
    visualisation_types,
    target_class,
    positive_only,
    **_,
) -> None:
    """Generates saliency maps for locations on input images that affected
    the prediction.

    The quality of the saliency information depends on the visualisation
    technique and on the trained model.
    """

    from ..engine.device import DeviceManager
    from ..engine.saliencymap_generator import run
    from .utils import save_sh_command

    logger.info(f"Output folder: {output_folder}")
    output_folder.mkdir(parents=True, exist_ok=True)

    save_sh_command(output_folder / "command.sh")

    device_manager = DeviceManager(device)

    # batch_size must be == 1 for now (underlying code is NOT prepared to
    # treat multiple samples at once).
    datamodule.set_chunk_size(1, 1)
    datamodule.cache_samples = cache_samples
    datamodule.parallel = parallel
    datamodule.model_transforms = model.model_transforms

    datamodule.prepare_data()
    datamodule.setup(stage="predict")

    logger.info(f"Loading checkpoint from `{weight}`...")
    model = model.load_from_checkpoint(weight, strict=False)

    run(
        model=model,
        datamodule=datamodule,
        device_manager=device_manager,
        visualisation_types=visualisation_types,
        target_class=target_class,
        positive_only=positive_only,
        output_folder=output_folder,
    )

(removed: the inline check_vis_types call, the Pasa/DenseNet target-layer
selection and the per-visualisation-type CAM loop, all of which moved into
engine.saliencymap_generator.run)
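
Since _save_visualisation stores each map with numpy.save, the generated
artefacts can be loaded back for inspection or overlay. A minimal sketch (the
file path is illustrative; actual names mirror the dataset's sample names):

import numpy

# pytorch-grad-cam emits one float heatmap per input image, normalised to
# [0, 1]; with the batch size fixed at 1, each .npy file holds an array of
# shape (1, H, W).
vis = numpy.load("visualisations/gradcam/sample-0001.npy")
print(vis.shape, vis.min(), vis.max())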