From 7e66943c7ab7ac7c1980a2290eaf9655818e21df Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Wed, 26 Jun 2024 10:49:50 +0200
Subject: [PATCH] [segmentation.predict] Make prediction work with new samples

---
 src/mednet/libs/segmentation/engine/predictor.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/mednet/libs/segmentation/engine/predictor.py b/src/mednet/libs/segmentation/engine/predictor.py
index a33a3b80..a7fb4e14 100644
--- a/src/mednet/libs/segmentation/engine/predictor.py
+++ b/src/mednet/libs/segmentation/engine/predictor.py
@@ -78,7 +78,7 @@ class _HDF5Writer(lightning.pytorch.callbacks.BasePredictionWriter):
             with h5py.File(output_path, "w") as f:
                 f.create_dataset(
                     "image",
-                    data=batch[0][k].cpu().numpy(),
+                    data=batch[0]["image"][k].cpu().numpy(),
                     compression="gzip",
                     compression_opts=9,
                 )
@@ -90,13 +90,13 @@ class _HDF5Writer(lightning.pytorch.callbacks.BasePredictionWriter):
                 )
                 f.create_dataset(
                     "target",
-                    data=(batch[1]["target"][k].squeeze(0).cpu().numpy() > 0.5),
+                    data=(batch[0]["target"][k].squeeze(0).cpu().numpy() > 0.5),
                     compression="gzip",
                     compression_opts=9,
                 )
                 f.create_dataset(
                     "mask",
-                    data=(batch[1]["mask"][k].squeeze(0).cpu().numpy() > 0.5),
+                    data=(batch[0]["mask"][k].squeeze(0).cpu().numpy() > 0.5),
                     compression="gzip",
                     compression_opts=9,
                 )
-- 
GitLab