Skip to content
Snippets Groups Projects
Commit e1f31348 authored by Daniel CARRON's avatar Daniel CARRON :b:
Browse files

[saliency] Remove use-cuda parameter

Cuda usage is now inferred from the model.
parent 86751c88
No related branches found
No related tags found
1 merge request!20Update dependence to lightning (closes #61)
Pipeline #84389 passed
......@@ -22,7 +22,6 @@ def _create_saliency_map_callable(
algo_type: SaliencyMapAlgorithm,
model: torch.nn.Module,
target_layers: list[torch.nn.Module] | None,
use_cuda: bool,
):
"""Create a class activation map (CAM) instance for a given model.
......@@ -34,8 +33,6 @@ def _create_saliency_map_callable(
Neural network model (e.g. pasa).
target_layers
The target layers to compute CAM for.
use_cuda
Whether to use cuda or not.
Returns
-------
......@@ -47,54 +44,54 @@ def _create_saliency_map_callable(
match algo_type:
case "gradcam":
return pytorch_grad_cam.GradCAM(
model=model, target_layers=target_layers, use_cuda=use_cuda
model=model, target_layers=target_layers
)
case "scorecam":
return pytorch_grad_cam.ScoreCAM(
model=model, target_layers=target_layers, use_cuda=use_cuda
model=model, target_layers=target_layers
)
case "fullgrad":
return pytorch_grad_cam.FullGrad(
model=model, target_layers=target_layers, use_cuda=use_cuda
model=model, target_layers=target_layers
)
case "randomcam":
return pytorch_grad_cam.RandomCAM(
model=model, target_layers=target_layers, use_cuda=use_cuda
model=model, target_layers=target_layers
)
case "hirescam":
return pytorch_grad_cam.HiResCAM(
model=model, target_layers=target_layers, use_cuda=use_cuda
model=model, target_layers=target_layers
)
case "gradcamelementwise":
return pytorch_grad_cam.GradCAMElementWise(
model=model, target_layers=target_layers, use_cuda=use_cuda
model=model, target_layers=target_layers
)
case "gradcam++", "gradcamplusplus":
return pytorch_grad_cam.GradCAMPlusPlus(
model=model, target_layers=target_layers, use_cuda=use_cuda
model=model, target_layers=target_layers
)
case "xgradcam":
return pytorch_grad_cam.XGradCAM(
model=model, target_layers=target_layers, use_cuda=use_cuda
model=model, target_layers=target_layers
)
case "ablationcam":
assert (
target_layers is not None
), "AblationCAM cannot have target_layers=None"
return pytorch_grad_cam.AblationCAM(
model=model, target_layers=target_layers, use_cuda=use_cuda
model=model, target_layers=target_layers
)
case "eigencam":
return pytorch_grad_cam.EigenCAM(
model=model, target_layers=target_layers, use_cuda=use_cuda
model=model, target_layers=target_layers
)
case "eigengradcam":
return pytorch_grad_cam.EigenGradCAM(
model=model, target_layers=target_layers, use_cuda=use_cuda
model=model, target_layers=target_layers
)
case "layercam":
return pytorch_grad_cam.LayerCAM(
model=model, target_layers=target_layers, use_cuda=use_cuda
model=model, target_layers=target_layers
)
case _:
raise ValueError(
......@@ -180,8 +177,6 @@ def run(
else:
raise TypeError(f"Model of type `{type(model)}` is not yet supported.")
use_cuda = device_manager.device_type == "cuda"
# prepares model for evaluation, cast to target device
device = device_manager.torch_device()
model = model.to(device)
......@@ -191,7 +186,6 @@ def run(
saliency_map_algorithm,
model,
target_layers, # type: ignore
use_cuda,
)
for k, v in datamodule.predict_dataloader().items():
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment