diff --git a/src/mednet/engine/saliency/generator.py b/src/mednet/engine/saliency/generator.py
index e99f78879f6049b6520c1525ef69b2db4f89143e..90f2e433ff03ff07f53fb9eceefaf8084ede7789 100644
--- a/src/mednet/engine/saliency/generator.py
+++ b/src/mednet/engine/saliency/generator.py
@@ -22,7 +22,6 @@ def _create_saliency_map_callable(
     algo_type: SaliencyMapAlgorithm,
     model: torch.nn.Module,
     target_layers: list[torch.nn.Module] | None,
-    use_cuda: bool,
 ):
     """Create a class activation map (CAM) instance for a given model.
 
@@ -34,8 +33,6 @@ def _create_saliency_map_callable(
         Neural network model (e.g. pasa).
     target_layers
         The target layers to compute CAM for.
-    use_cuda
-        Whether to use cuda or not.
 
     Returns
     -------
@@ -47,54 +44,54 @@ def _create_saliency_map_callable(
     match algo_type:
         case "gradcam":
             return pytorch_grad_cam.GradCAM(
-                model=model, target_layers=target_layers, use_cuda=use_cuda
+                model=model, target_layers=target_layers
             )
         case "scorecam":
             return pytorch_grad_cam.ScoreCAM(
-                model=model, target_layers=target_layers, use_cuda=use_cuda
+                model=model, target_layers=target_layers
             )
         case "fullgrad":
             return pytorch_grad_cam.FullGrad(
-                model=model, target_layers=target_layers, use_cuda=use_cuda
+                model=model, target_layers=target_layers
             )
         case "randomcam":
             return pytorch_grad_cam.RandomCAM(
-                model=model, target_layers=target_layers, use_cuda=use_cuda
+                model=model, target_layers=target_layers
             )
         case "hirescam":
             return pytorch_grad_cam.HiResCAM(
-                model=model, target_layers=target_layers, use_cuda=use_cuda
+                model=model, target_layers=target_layers
             )
         case "gradcamelementwise":
             return pytorch_grad_cam.GradCAMElementWise(
-                model=model, target_layers=target_layers, use_cuda=use_cuda
+                model=model, target_layers=target_layers
             )
         case "gradcam++", "gradcamplusplus":
             return pytorch_grad_cam.GradCAMPlusPlus(
-                model=model, target_layers=target_layers, use_cuda=use_cuda
+                model=model, target_layers=target_layers
             )
         case "xgradcam":
             return pytorch_grad_cam.XGradCAM(
-                model=model, target_layers=target_layers, use_cuda=use_cuda
+                model=model, target_layers=target_layers
             )
         case "ablationcam":
             assert (
                 target_layers is not None
             ), "AblationCAM cannot have target_layers=None"
             return pytorch_grad_cam.AblationCAM(
-                model=model, target_layers=target_layers, use_cuda=use_cuda
+                model=model, target_layers=target_layers
             )
         case "eigencam":
             return pytorch_grad_cam.EigenCAM(
-                model=model, target_layers=target_layers, use_cuda=use_cuda
+                model=model, target_layers=target_layers
             )
         case "eigengradcam":
             return pytorch_grad_cam.EigenGradCAM(
-                model=model, target_layers=target_layers, use_cuda=use_cuda
+                model=model, target_layers=target_layers
             )
         case "layercam":
             return pytorch_grad_cam.LayerCAM(
-                model=model, target_layers=target_layers, use_cuda=use_cuda
+                model=model, target_layers=target_layers
             )
         case _:
             raise ValueError(
@@ -180,8 +177,6 @@ def run(
     else:
         raise TypeError(f"Model of type `{type(model)}` is not yet supported.")
 
-    use_cuda = device_manager.device_type == "cuda"
-
     # prepares model for evaluation, cast to target device
     device = device_manager.torch_device()
     model = model.to(device)
@@ -191,7 +186,6 @@ def run(
         saliency_map_algorithm,
         model,
         target_layers,  # type: ignore
-        use_cuda,
     )
 
     for k, v in datamodule.predict_dataloader().items():