From f63421250a8f6a4ee3ee2d170468304c9ec491db Mon Sep 17 00:00:00 2001 From: dcarron <daniel.carron@idiap.ch> Date: Mon, 19 Feb 2024 17:31:04 +0100 Subject: [PATCH] [engine.saliency.viewer] Fix --show-groundtruth --- src/mednet/engine/saliency/viewer.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/mednet/engine/saliency/viewer.py b/src/mednet/engine/saliency/viewer.py index 7caae835..0047c4e1 100644 --- a/src/mednet/engine/saliency/viewer.py +++ b/src/mednet/engine/saliency/viewer.py @@ -226,6 +226,7 @@ def run( for sample in tqdm( dataset_loader, desc="batches", leave=False, disable=None ): + # WARNING: following code assumes a batch size of 1. Will break if not the case. name = str(sample[1]["name"][0]) label = int(sample[1]["label"].item()) data = sample[0][0] @@ -243,7 +244,9 @@ def run( # regions of interest. We need to abstract from this to support more # datasets and other ways to annotate. if show_groundtruth: - ground_truth = sample[1].get("bounding_boxes", BoundingBoxes()) + ground_truth = sample[1].get("bounding_boxes", BoundingBoxes())[ + 0 + ] else: ground_truth = BoundingBoxes() -- GitLab