From f01f65e4c3a4fe682541ed0042f54c5f91321650 Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Tue, 19 Dec 2023 15:32:00 +0100
Subject: [PATCH] [engine.saliency.viewer] Fix use of heatmap from matplotlib

---
 src/ptbench/engine/saliency/viewer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/ptbench/engine/saliency/viewer.py b/src/ptbench/engine/saliency/viewer.py
index 43011a38..3c0a7efe 100644
--- a/src/ptbench/engine/saliency/viewer.py
+++ b/src/ptbench/engine/saliency/viewer.py
@@ -102,7 +102,7 @@ def _overlay_saliency_map(
     result = numpy.where(
         saliencies[..., numpy.newaxis] == 0,
         image_array,
-        (image_weight * image_array) + ((1 - image_weight) * heatmap),
+        (image_weight * image_array) + ((1 - image_weight) * heatmap[:, :, :3]),
     )
 
     return PIL.Image.fromarray((result * 255).astype(numpy.uint8), "RGB")
-- 
GitLab