diff --git a/src/mednet/engine/saliency/viewer.py b/src/mednet/engine/saliency/viewer.py index 7caae835739af74d8dac7baf3c38d987824b9350..0047c4e1ad7ce6ee7f9a3cdd6482a234afb11756 100644 --- a/src/mednet/engine/saliency/viewer.py +++ b/src/mednet/engine/saliency/viewer.py @@ -226,6 +226,7 @@ def run( for sample in tqdm( dataset_loader, desc="batches", leave=False, disable=None ): + # WARNING: following code assumes a batch size of 1. Will break if not the case. name = str(sample[1]["name"][0]) label = int(sample[1]["label"].item()) data = sample[0][0] @@ -243,7 +244,9 @@ def run( # regions of interest. We need to abstract from this to support more # datasets and other ways to annotate. if show_groundtruth: - ground_truth = sample[1].get("bounding_boxes", BoundingBoxes()) + ground_truth = sample[1].get("bounding_boxes", BoundingBoxes())[ + 0 + ] else: ground_truth = BoundingBoxes()