diff --git a/src/ptbench/engine/saliency/viewer.py b/src/ptbench/engine/saliency/viewer.py
index 43011a385698f2a2dc3d2b19b256a2b4c1f7d1b2..3c0a7efe1300e069c47189274fc556f177050759 100644
--- a/src/ptbench/engine/saliency/viewer.py
+++ b/src/ptbench/engine/saliency/viewer.py
@@ -102,7 +102,7 @@ def _overlay_saliency_map(
     result = numpy.where(
         saliencies[..., numpy.newaxis] == 0,
         image_array,
-        (image_weight * image_array) + ((1 - image_weight) * heatmap),
+        (image_weight * image_array) + ((1 - image_weight) * heatmap[:, :, :3]),
     )
 
     return PIL.Image.fromarray((result * 255).astype(numpy.uint8), "RGB")