diff --git a/src/mednet/libs/segmentation/engine/predictor.py b/src/mednet/libs/segmentation/engine/predictor.py index a33a3b80cd786ebf17a9220bd513a7597a649655..a7fb4e146582a5c46cd1b249b425f497423fe024 100644 --- a/src/mednet/libs/segmentation/engine/predictor.py +++ b/src/mednet/libs/segmentation/engine/predictor.py @@ -78,7 +78,7 @@ class _HDF5Writer(lightning.pytorch.callbacks.BasePredictionWriter): with h5py.File(output_path, "w") as f: f.create_dataset( "image", - data=batch[0][k].cpu().numpy(), + data=batch[0]["image"][k].cpu().numpy(), compression="gzip", compression_opts=9, ) @@ -90,13 +90,13 @@ class _HDF5Writer(lightning.pytorch.callbacks.BasePredictionWriter): ) f.create_dataset( "target", - data=(batch[1]["target"][k].squeeze(0).cpu().numpy() > 0.5), + data=(batch[0]["target"][k].squeeze(0).cpu().numpy() > 0.5), compression="gzip", compression_opts=9, ) f.create_dataset( "mask", - data=(batch[1]["mask"][k].squeeze(0).cpu().numpy() > 0.5), + data=(batch[0]["mask"][k].squeeze(0).cpu().numpy() > 0.5), compression="gzip", compression_opts=9, )