Commit aea0a2e2 authored by Daniel CARRON, committed by André Anjos

Predict on all available sets

parent 541cc0a0
Merge request !6: Making use of LightningDataModule and simplification of data loading
@@ -686,7 +686,9 @@ class CachingDataModule(lightning.LightningDataModule):
     def _val_dataset_keys(self) -> list[str]:
         """Returns list of validation dataset names."""
         return ["validation"] + [
-            k for k in self.database_split.keys() if k.startswith("monitor-")
+            k
+            for k in self.database_split.subsets.keys()
+            if k.startswith("monitor-")
         ]

     def setup(self, stage: str) -> None:
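Note: the comprehension now reads subset names from database_split.subsets rather than treating the split itself as a mapping. A minimal sketch of the selection logic, with a hypothetical subset dict standing in for the real split object:

# Hypothetical stand-in for database_split.subsets; the real
# object maps subset names to lists of samples.
subsets = {
    "train": [],
    "validation": [],
    "monitor-train": [],
    "test": [],
}

# Same selection as _val_dataset_keys() above: "validation"
# plus every subset whose name starts with "monitor-".
val_keys = ["validation"] + [
    k for k in subsets.keys() if k.startswith("monitor-")
]
print(val_keys)  # ['validation', 'monitor-train']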
@@ -727,7 +729,8 @@ class CachingDataModule(lightning.LightningDataModule):
             self._setup_dataset("test")
         elif stage == "predict":
-            self._setup_dataset("test")
+            for k in self.database_split.subsets.keys():
+                self._setup_dataset(k)

     def teardown(self, stage: str) -> None:
         """Unset-up datasets for different tasks on the pipeline.
@@ -814,4 +817,14 @@ class CachingDataModule(lightning.LightningDataModule):
     def predict_dataloader(self) -> dict[str, DataLoader]:
         """Returns the prediction data loader(s)"""
-        return self.test_dataloader()
+        return {
+            k: torch.utils.data.DataLoader(
+                self._datasets[k],
+                batch_size=self._chunk_size,
+                shuffle=False,
+                drop_last=self.drop_incomplete_batch,
+                pin_memory=self.pin_memory,
+                **self._dataloader_multiproc,
+            )
+            for k in self._datasets
+        }
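predict_dataloader() now returns one DataLoader per subset, keyed by subset name, instead of delegating to test_dataloader(). A hypothetical consumer of that return type (dataset contents are fabricated, and the caching knobs from the real code are omitted):

import torch
import torch.utils.data

# Fabricated datasets standing in for self._datasets.
datasets = {
    "train": torch.utils.data.TensorDataset(torch.zeros(4, 1)),
    "test": torch.utils.data.TensorDataset(torch.ones(2, 1)),
}

# Mirrors the dict comprehension above.
loaders = {
    k: torch.utils.data.DataLoader(datasets[k], batch_size=2, shuffle=False)
    for k in datasets
}

for name, loader in loaders.items():
    n = sum(len(batch[0]) for batch in loader)
    print(f"{name}: {n} samples")  # train: 4 samples / test: 2 samples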
@@ -402,27 +402,28 @@ class PredictionsWriter(lightning.pytorch.callbacks.BasePredictionWriter):
         predictions: typing.Sequence[typing.Any],
         batch_indices: typing.Sequence[typing.Any] | None,
     ) -> None:
-        dataloader_name = list(trainer.datamodule.predict_dataloader().keys())[
-            0
-        ]
-        logfile = os.path.join(
-            self.output_dir, f"predictions_{dataloader_name}", "predictions.csv"
-        )
-        os.makedirs(os.path.dirname(logfile), exist_ok=True)
-        logger.info(f"Saving predictions in {logfile}.")
-        with open(logfile, "w") as l_f:
-            logwriter = csv.DictWriter(l_f, fieldnames=self.logfile_fields)
-            logwriter.writeheader()
-            for prediction in predictions:
-                logwriter.writerow(
-                    {
-                        "filename": prediction[0],
-                        "likelihood": prediction[1].numpy(),
-                        "ground_truth": prediction[2].numpy(),
-                    }
-                )
-            l_f.flush()
+        dataloader_names = list(trainer.datamodule.predict_dataloader().keys())
+        for dataloader_idx, dataloader_name in enumerate(dataloader_names):
+            logfile = os.path.join(
+                self.output_dir,
+                f"predictions_{dataloader_name}",
+                "predictions.csv",
+            )
+            os.makedirs(os.path.dirname(logfile), exist_ok=True)
+            logger.info(f"Saving predictions in {logfile}.")
+            with open(logfile, "w") as l_f:
+                logwriter = csv.DictWriter(l_f, fieldnames=self.logfile_fields)
+                logwriter.writeheader()
+                for prediction in predictions[dataloader_idx]:
+                    logwriter.writerow(
+                        {
+                            "filename": prediction[0],
+                            "likelihood": prediction[1].numpy(),
+                            "ground_truth": prediction[2].numpy(),
+                        }
+                    )
+                l_f.flush()
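When Trainer.predict runs with several dataloaders, Lightning passes write_on_epoch_end a list with one entry per dataloader, which is what the predictions[dataloader_idx] indexing above assumes. A self-contained sketch of the resulting on-disk layout, with fabricated prediction tuples and illustrative subset names and paths:

import csv
import os
import tempfile

fields = ["filename", "likelihood", "ground_truth"]

# One inner list per predict dataloader, as Lightning nests them;
# the tuples are fabricated (filename, likelihood, ground truth).
predictions = [
    [("a.png", 0.9, 1), ("b.png", 0.2, 0)],  # e.g. the "train" loader
    [("c.png", 0.7, 1)],                     # e.g. the "test" loader
]

output_dir = tempfile.mkdtemp()
for idx, name in enumerate(["train", "test"]):
    logfile = os.path.join(output_dir, f"predictions_{name}", "predictions.csv")
    os.makedirs(os.path.dirname(logfile), exist_ok=True)
    with open(logfile, "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=fields)
        writer.writeheader()
        for filename, likelihood, ground_truth in predictions[idx]:
            writer.writerow(
                {
                    "filename": filename,
                    "likelihood": likelihood,
                    "ground_truth": ground_truth,
                }
            )
    print(logfile)  # one predictions.csv per dataloader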
@@ -102,6 +102,8 @@ def predict(
     datamodule.set_chunk_size(batch_size, 1)
     datamodule.model_transforms = model.model_transforms

     datamodule.prepare_data()
+    datamodule.setup(stage="predict")
+
     logger.info(f"Loading checkpoint from {weight}")
     model = model.load_from_checkpoint(weight, strict=False)
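Calling setup(stage="predict") explicitly after prepare_data() ensures the per-subset datasets exist before the trainer asks for prediction dataloaders. A minimal sketch of that call order, with a hypothetical stand-in datamodule (only the hook sequence is shown, not the real caching logic):

import lightning

class FakeDataModule(lightning.LightningDataModule):
    # Stand-in for CachingDataModule; only prints the hook order.
    def prepare_data(self) -> None:
        print("prepare_data: download / cache raw data once")

    def setup(self, stage: str) -> None:
        print(f"setup: build datasets for stage={stage!r}")

dm = FakeDataModule()
dm.prepare_data()
dm.setup(stage="predict")  # every subset is now ready for Trainer.predict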