diff --git a/src/ptbench/engine/saliency/viewer.py b/src/ptbench/engine/saliency/viewer.py index 43011a385698f2a2dc3d2b19b256a2b4c1f7d1b2..3c0a7efe1300e069c47189274fc556f177050759 100644 --- a/src/ptbench/engine/saliency/viewer.py +++ b/src/ptbench/engine/saliency/viewer.py @@ -102,7 +102,7 @@ def _overlay_saliency_map( result = numpy.where( saliencies[..., numpy.newaxis] == 0, image_array, - (image_weight * image_array) + ((1 - image_weight) * heatmap), + (image_weight * image_array) + ((1 - image_weight) * heatmap[:, :, :3]), ) return PIL.Image.fromarray((result * 255).astype(numpy.uint8), "RGB")