Skip to content
Snippets Groups Projects
Commit 7e66943c authored by Daniel CARRON's avatar Daniel CARRON :b: Committed by André Anjos
Browse files

[segmentation.predict] Make prediction work with new samples

parent 411a9285
No related branches found
No related tags found
1 merge request!46Create common library
...@@ -78,7 +78,7 @@ class _HDF5Writer(lightning.pytorch.callbacks.BasePredictionWriter): ...@@ -78,7 +78,7 @@ class _HDF5Writer(lightning.pytorch.callbacks.BasePredictionWriter):
with h5py.File(output_path, "w") as f: with h5py.File(output_path, "w") as f:
f.create_dataset( f.create_dataset(
"image", "image",
data=batch[0][k].cpu().numpy(), data=batch[0]["image"][k].cpu().numpy(),
compression="gzip", compression="gzip",
compression_opts=9, compression_opts=9,
) )
...@@ -90,13 +90,13 @@ class _HDF5Writer(lightning.pytorch.callbacks.BasePredictionWriter): ...@@ -90,13 +90,13 @@ class _HDF5Writer(lightning.pytorch.callbacks.BasePredictionWriter):
) )
f.create_dataset( f.create_dataset(
"target", "target",
data=(batch[1]["target"][k].squeeze(0).cpu().numpy() > 0.5), data=(batch[0]["target"][k].squeeze(0).cpu().numpy() > 0.5),
compression="gzip", compression="gzip",
compression_opts=9, compression_opts=9,
) )
f.create_dataset( f.create_dataset(
"mask", "mask",
data=(batch[1]["mask"][k].squeeze(0).cpu().numpy() > 0.5), data=(batch[0]["mask"][k].squeeze(0).cpu().numpy() > 0.5),
compression="gzip", compression="gzip",
compression_opts=9, compression_opts=9,
) )
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment