diff --git a/src/mednet/engine/saliency/completeness.py b/src/mednet/engine/saliency/completeness.py index 6c8be633336b0e8a5f5150f5d2946fe14cb16fbd..197822ddf28ddc60c14aa039c457e3d4453c5652 100644 --- a/src/mednet/engine/saliency/completeness.py +++ b/src/mednet/engine/saliency/completeness.py @@ -300,7 +300,6 @@ def run( else: raise TypeError(f"Model of type `{type(model)}` is not yet supported.") - use_cuda = device_manager.device_type == "cuda" if device_manager.device_type in ("cuda", "mps") and ( parallel == 0 or parallel > 1 ): @@ -322,7 +321,6 @@ def run( saliency_map_algorithm, model, target_layers, # type: ignore - use_cuda, ) retval: dict[str, list[typing.Any]] = {} diff --git a/src/mednet/engine/saliency/generator.py b/src/mednet/engine/saliency/generator.py index 90f2e433ff03ff07f53fb9eceefaf8084ede7789..df0f9ab123f2ead0b17d361bb5dce73196438184 100644 --- a/src/mednet/engine/saliency/generator.py +++ b/src/mednet/engine/saliency/generator.py @@ -66,7 +66,7 @@ def _create_saliency_map_callable( return pytorch_grad_cam.GradCAMElementWise( model=model, target_layers=target_layers ) - case "gradcam++", "gradcamplusplus": + case "gradcam++" | "gradcamplusplus": return pytorch_grad_cam.GradCAMPlusPlus( model=model, target_layers=target_layers )