From f63421250a8f6a4ee3ee2d170468304c9ec491db Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Mon, 19 Feb 2024 17:31:04 +0100
Subject: [PATCH] [engine.saliency.viewer] Fix --show-groundtruth

---
 src/mednet/engine/saliency/viewer.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/src/mednet/engine/saliency/viewer.py b/src/mednet/engine/saliency/viewer.py
index 7caae835..0047c4e1 100644
--- a/src/mednet/engine/saliency/viewer.py
+++ b/src/mednet/engine/saliency/viewer.py
@@ -226,6 +226,7 @@ def run(
         for sample in tqdm(
             dataset_loader, desc="batches", leave=False, disable=None
         ):
+            # WARNING: following code assumes a batch size of 1. Will break if not the case.
             name = str(sample[1]["name"][0])
             label = int(sample[1]["label"].item())
             data = sample[0][0]
@@ -243,7 +244,9 @@ def run(
             # regions of interest.  We need to abstract from this to support more
             # datasets and other ways to annotate.
             if show_groundtruth:
-                ground_truth = sample[1].get("bounding_boxes", BoundingBoxes())
+                ground_truth = sample[1].get("bounding_boxes", BoundingBoxes())[
+                    0
+                ]
             else:
                 ground_truth = BoundingBoxes()
 
-- 
GitLab