Skip to content
Snippets Groups Projects
Commit aea0a2e2 authored by Daniel CARRON's avatar Daniel CARRON :b: Committed by André Anjos
Browse files

Predict on all available sets

parent 541cc0a0
No related branches found
No related tags found
1 merge request!6Making use of LightningDataModule and simplification of data loading
...@@ -686,7 +686,9 @@ class CachingDataModule(lightning.LightningDataModule): ...@@ -686,7 +686,9 @@ class CachingDataModule(lightning.LightningDataModule):
def _val_dataset_keys(self) -> list[str]: def _val_dataset_keys(self) -> list[str]:
"""Returns list of validation dataset names.""" """Returns list of validation dataset names."""
return ["validation"] + [ return ["validation"] + [
k for k in self.database_split.keys() if k.startswith("monitor-") k
for k in self.database_split.subsets.keys()
if k.startswith("monitor-")
] ]
def setup(self, stage: str) -> None: def setup(self, stage: str) -> None:
...@@ -727,7 +729,8 @@ class CachingDataModule(lightning.LightningDataModule): ...@@ -727,7 +729,8 @@ class CachingDataModule(lightning.LightningDataModule):
self._setup_dataset("test") self._setup_dataset("test")
elif stage == "predict": elif stage == "predict":
self._setup_dataset("test") for k in self.database_split.subsets.keys():
self._setup_dataset(k)
def teardown(self, stage: str) -> None: def teardown(self, stage: str) -> None:
"""Unset-up datasets for different tasks on the pipeline. """Unset-up datasets for different tasks on the pipeline.
...@@ -814,4 +817,14 @@ class CachingDataModule(lightning.LightningDataModule): ...@@ -814,4 +817,14 @@ class CachingDataModule(lightning.LightningDataModule):
def predict_dataloader(self) -> dict[str, DataLoader]:
    """Return the prediction data loader(s).

    Builds one :class:`torch.utils.data.DataLoader` per dataset already
    set up in ``self._datasets`` (one for every available split), so
    prediction runs over *all* sets rather than just the test set.

    Returns
    -------
    dict[str, DataLoader]
        A loader per dataset name, without shuffling (prediction order
        must be deterministic), using the module's configured chunk
        size, incomplete-batch policy, memory pinning, and
        multiprocessing options.
    """
    loaders: dict[str, DataLoader] = {}
    for name, dataset in self._datasets.items():
        loaders[name] = torch.utils.data.DataLoader(
            dataset,
            batch_size=self._chunk_size,
            shuffle=False,  # keep sample order stable for prediction
            drop_last=self.drop_incomplete_batch,
            pin_memory=self.pin_memory,
            **self._dataloader_multiproc,
        )
    return loaders
...@@ -402,27 +402,28 @@ class PredictionsWriter(lightning.pytorch.callbacks.BasePredictionWriter): ...@@ -402,27 +402,28 @@ class PredictionsWriter(lightning.pytorch.callbacks.BasePredictionWriter):
predictions: typing.Sequence[typing.Any], predictions: typing.Sequence[typing.Any],
batch_indices: typing.Sequence[typing.Any] | None, batch_indices: typing.Sequence[typing.Any] | None,
) -> None: ) -> None:
dataloader_name = list(trainer.datamodule.predict_dataloader().keys())[ dataloader_names = list(trainer.datamodule.predict_dataloader().keys())
0
]
logfile = os.path.join( for dataloader_idx, dataloader_name in enumerate(dataloader_names):
self.output_dir, f"predictions_{dataloader_name}", "predictions.csv" logfile = os.path.join(
) self.output_dir,
os.makedirs(os.path.dirname(logfile), exist_ok=True) f"predictions_{dataloader_name}",
"predictions.csv",
logger.info(f"Saving predictions in {logfile}.") )
os.makedirs(os.path.dirname(logfile), exist_ok=True)
with open(logfile, "w") as l_f:
logwriter = csv.DictWriter(l_f, fieldnames=self.logfile_fields) logger.info(f"Saving predictions in {logfile}.")
logwriter.writeheader()
with open(logfile, "w") as l_f:
for prediction in predictions: logwriter = csv.DictWriter(l_f, fieldnames=self.logfile_fields)
logwriter.writerow( logwriter.writeheader()
{
"filename": prediction[0], for prediction in predictions[dataloader_idx]:
"likelihood": prediction[1].numpy(), logwriter.writerow(
"ground_truth": prediction[2].numpy(), {
} "filename": prediction[0],
) "likelihood": prediction[1].numpy(),
l_f.flush() "ground_truth": prediction[2].numpy(),
}
)
l_f.flush()
...@@ -102,6 +102,8 @@ def predict( ...@@ -102,6 +102,8 @@ def predict(
datamodule.set_chunk_size(batch_size, 1) datamodule.set_chunk_size(batch_size, 1)
datamodule.model_transforms = model.model_transforms datamodule.model_transforms = model.model_transforms
datamodule.prepare_data()
datamodule.setup(stage="predict")
logger.info(f"Loading checkpoint from {weight}") logger.info(f"Loading checkpoint from {weight}")
model = model.load_from_checkpoint(weight, strict=False) model = model.load_from_checkpoint(weight, strict=False)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment