diff --git a/src/mednet/libs/segmentation/engine/predictor.py b/src/mednet/libs/segmentation/engine/predictor.py
index a33a3b80cd786ebf17a9220bd513a7597a649655..a7fb4e146582a5c46cd1b249b425f497423fe024 100644
--- a/src/mednet/libs/segmentation/engine/predictor.py
+++ b/src/mednet/libs/segmentation/engine/predictor.py
@@ -78,7 +78,7 @@ class _HDF5Writer(lightning.pytorch.callbacks.BasePredictionWriter):
             with h5py.File(output_path, "w") as f:
                 f.create_dataset(
                     "image",
-                    data=batch[0][k].cpu().numpy(),
+                    data=batch[0]["image"][k].cpu().numpy(),
                     compression="gzip",
                     compression_opts=9,
                 )
@@ -90,13 +90,13 @@ class _HDF5Writer(lightning.pytorch.callbacks.BasePredictionWriter):
                 )
                 f.create_dataset(
                     "target",
-                    data=(batch[1]["target"][k].squeeze(0).cpu().numpy() > 0.5),
+                    data=(batch[0]["target"][k].squeeze(0).cpu().numpy() > 0.5),
                     compression="gzip",
                     compression_opts=9,
                 )
                 f.create_dataset(
                     "mask",
-                    data=(batch[1]["mask"][k].squeeze(0).cpu().numpy() > 0.5),
+                    data=(batch[0]["mask"][k].squeeze(0).cpu().numpy() > 0.5),
                     compression="gzip",
                     compression_opts=9,
                 )