diff --git a/src/mednet/engine/saliency/generator.py b/src/mednet/engine/saliency/generator.py index e99f78879f6049b6520c1525ef69b2db4f89143e..90f2e433ff03ff07f53fb9eceefaf8084ede7789 100644 --- a/src/mednet/engine/saliency/generator.py +++ b/src/mednet/engine/saliency/generator.py @@ -22,7 +22,6 @@ def _create_saliency_map_callable( algo_type: SaliencyMapAlgorithm, model: torch.nn.Module, target_layers: list[torch.nn.Module] | None, - use_cuda: bool, ): """Create a class activation map (CAM) instance for a given model. @@ -34,8 +33,6 @@ def _create_saliency_map_callable( Neural network model (e.g. pasa). target_layers The target layers to compute CAM for. - use_cuda - Whether to use cuda or not. Returns ------- @@ -47,54 +44,54 @@ def _create_saliency_map_callable( match algo_type: case "gradcam": return pytorch_grad_cam.GradCAM( - model=model, target_layers=target_layers, use_cuda=use_cuda + model=model, target_layers=target_layers ) case "scorecam": return pytorch_grad_cam.ScoreCAM( - model=model, target_layers=target_layers, use_cuda=use_cuda + model=model, target_layers=target_layers ) case "fullgrad": return pytorch_grad_cam.FullGrad( - model=model, target_layers=target_layers, use_cuda=use_cuda + model=model, target_layers=target_layers ) case "randomcam": return pytorch_grad_cam.RandomCAM( - model=model, target_layers=target_layers, use_cuda=use_cuda + model=model, target_layers=target_layers ) case "hirescam": return pytorch_grad_cam.HiResCAM( - model=model, target_layers=target_layers, use_cuda=use_cuda + model=model, target_layers=target_layers ) case "gradcamelementwise": return pytorch_grad_cam.GradCAMElementWise( - model=model, target_layers=target_layers, use_cuda=use_cuda + model=model, target_layers=target_layers ) case "gradcam++", "gradcamplusplus": return pytorch_grad_cam.GradCAMPlusPlus( - model=model, target_layers=target_layers, use_cuda=use_cuda + model=model, target_layers=target_layers ) case "xgradcam": return pytorch_grad_cam.XGradCAM( - model=model, target_layers=target_layers, use_cuda=use_cuda + model=model, target_layers=target_layers ) case "ablationcam": assert ( target_layers is not None ), "AblationCAM cannot have target_layers=None" return pytorch_grad_cam.AblationCAM( - model=model, target_layers=target_layers, use_cuda=use_cuda + model=model, target_layers=target_layers ) case "eigencam": return pytorch_grad_cam.EigenCAM( - model=model, target_layers=target_layers, use_cuda=use_cuda + model=model, target_layers=target_layers ) case "eigengradcam": return pytorch_grad_cam.EigenGradCAM( - model=model, target_layers=target_layers, use_cuda=use_cuda + model=model, target_layers=target_layers ) case "layercam": return pytorch_grad_cam.LayerCAM( - model=model, target_layers=target_layers, use_cuda=use_cuda + model=model, target_layers=target_layers ) case _: raise ValueError( @@ -180,8 +177,6 @@ def run( else: raise TypeError(f"Model of type `{type(model)}` is not yet supported.") - use_cuda = device_manager.device_type == "cuda" - # prepares model for evaluation, cast to target device device = device_manager.torch_device() model = model.to(device) @@ -191,7 +186,6 @@ def run( saliency_map_algorithm, model, target_layers, # type: ignore - use_cuda, ) for k, v in datamodule.predict_dataloader().items():