From 7e66943c7ab7ac7c1980a2290eaf9655818e21df Mon Sep 17 00:00:00 2001 From: dcarron <daniel.carron@idiap.ch> Date: Wed, 26 Jun 2024 10:49:50 +0200 Subject: [PATCH] [segmentation.predict] Make prediction work with new samples --- src/mednet/libs/segmentation/engine/predictor.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/mednet/libs/segmentation/engine/predictor.py b/src/mednet/libs/segmentation/engine/predictor.py index a33a3b80..a7fb4e14 100644 --- a/src/mednet/libs/segmentation/engine/predictor.py +++ b/src/mednet/libs/segmentation/engine/predictor.py @@ -78,7 +78,7 @@ class _HDF5Writer(lightning.pytorch.callbacks.BasePredictionWriter): with h5py.File(output_path, "w") as f: f.create_dataset( "image", - data=batch[0][k].cpu().numpy(), + data=batch[0]["image"][k].cpu().numpy(), compression="gzip", compression_opts=9, ) @@ -90,13 +90,13 @@ class _HDF5Writer(lightning.pytorch.callbacks.BasePredictionWriter): ) f.create_dataset( "target", - data=(batch[1]["target"][k].squeeze(0).cpu().numpy() > 0.5), + data=(batch[0]["target"][k].squeeze(0).cpu().numpy() > 0.5), compression="gzip", compression_opts=9, ) f.create_dataset( "mask", - data=(batch[1]["mask"][k].squeeze(0).cpu().numpy() > 0.5), + data=(batch[0]["mask"][k].squeeze(0).cpu().numpy() > 0.5), compression="gzip", compression_opts=9, ) -- GitLab