diff --git a/src/ptbench/data/datamodule.py b/src/ptbench/data/datamodule.py
index bb2dcdda8ad100d91f6bbd35abd8bcb07e41a334..af0d513ac1bd166179f64fb759bf9e422da8bf11 100644
--- a/src/ptbench/data/datamodule.py
+++ b/src/ptbench/data/datamodule.py
@@ -686,7 +686,9 @@ class CachingDataModule(lightning.LightningDataModule):
     def _val_dataset_keys(self) -> list[str]:
         """Returns list of validation dataset names."""
         return ["validation"] + [
-            k for k in self.database_split.keys() if k.startswith("monitor-")
+            k
+            for k in self.database_split.subsets.keys()
+            if k.startswith("monitor-")
         ]
 
     def setup(self, stage: str) -> None:
@@ -727,7 +729,8 @@ class CachingDataModule(lightning.LightningDataModule):
             self._setup_dataset("test")
 
         elif stage == "predict":
-            self._setup_dataset("test")
+            for k in self.database_split.subsets.keys():
+                self._setup_dataset(k)
 
     def teardown(self, stage: str) -> None:
         """Unset-up datasets for different tasks on the pipeline.
@@ -814,4 +817,14 @@ class CachingDataModule(lightning.LightningDataModule):
 
     def predict_dataloader(self) -> dict[str, DataLoader]:
         """Returns the prediction data loader(s)"""
-        return self.test_dataloader()
+        return {
+            k: torch.utils.data.DataLoader(
+                self._datasets[k],
+                batch_size=self._chunk_size,
+                shuffle=False,
+                drop_last=self.drop_incomplete_batch,
+                pin_memory=self.pin_memory,
+                **self._dataloader_multiproc,
+            )
+            for k in self._datasets
+        }
diff --git a/src/ptbench/engine/callbacks.py b/src/ptbench/engine/callbacks.py
index 031acaaeb57855893ee99678381ade8d2a38f57c..c2ce035f73a37fad996181adb5d9412fb8e26556 100644
--- a/src/ptbench/engine/callbacks.py
+++ b/src/ptbench/engine/callbacks.py
@@ -402,27 +402,28 @@ class PredictionsWriter(lightning.pytorch.callbacks.BasePredictionWriter):
         predictions: typing.Sequence[typing.Any],
         batch_indices: typing.Sequence[typing.Any] | None,
     ) -> None:
-        dataloader_name = list(trainer.datamodule.predict_dataloader().keys())[
-            0
-        ]
+        dataloader_names = list(trainer.datamodule.predict_dataloader().keys())
 
-        logfile = os.path.join(
-            self.output_dir, f"predictions_{dataloader_name}", "predictions.csv"
-        )
-        os.makedirs(os.path.dirname(logfile), exist_ok=True)
-
-        logger.info(f"Saving predictions in {logfile}.")
-
-        with open(logfile, "w") as l_f:
-            logwriter = csv.DictWriter(l_f, fieldnames=self.logfile_fields)
-            logwriter.writeheader()
-
-            for prediction in predictions:
-                logwriter.writerow(
-                    {
-                        "filename": prediction[0],
-                        "likelihood": prediction[1].numpy(),
-                        "ground_truth": prediction[2].numpy(),
-                    }
-                )
-            l_f.flush()
+        for dataloader_idx, dataloader_name in enumerate(dataloader_names):
+            logfile = os.path.join(
+                self.output_dir,
+                f"predictions_{dataloader_name}",
+                "predictions.csv",
+            )
+            os.makedirs(os.path.dirname(logfile), exist_ok=True)
+
+            logger.info(f"Saving predictions in {logfile}.")
+
+            with open(logfile, "w") as l_f:
+                logwriter = csv.DictWriter(l_f, fieldnames=self.logfile_fields)
+                logwriter.writeheader()
+
+                for prediction in predictions[dataloader_idx]:
+                    logwriter.writerow(
+                        {
+                            "filename": prediction[0],
+                            "likelihood": prediction[1].numpy(),
+                            "ground_truth": prediction[2].numpy(),
+                        }
+                    )
+                l_f.flush()
diff --git a/src/ptbench/scripts/predict.py b/src/ptbench/scripts/predict.py
index 3de96bdd77aea7390df6f87d14bd835861049d43..b5ed123c2da9df3b1c485b999ac5eb9bfa04c9ee 100644
--- a/src/ptbench/scripts/predict.py
+++ b/src/ptbench/scripts/predict.py
@@ -102,6 +102,8 @@ def predict(
     datamodule.set_chunk_size(batch_size, 1)
     datamodule.model_transforms = model.model_transforms
 
+    datamodule.prepare_data()
+    datamodule.setup(stage="predict")
 
     logger.info(f"Loading checkpoint from {weight}")
     model = model.load_from_checkpoint(weight, strict=False)
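
Note on the callbacks.py change, not part of the patch: when predict_dataloader()
returns more than one loader, Lightning hands BasePredictionWriter.write_on_epoch_end
one sequence of outputs per dataloader, in dataloader order, which is why the writer
now indexes predictions[dataloader_idx] instead of iterating a flat list. A minimal
sketch of the resulting per-subset CSV layout follows; the subset names, filenames,
scores, and the "out" directory are all made up for illustration.

# Illustrative sketch only -- names and values below are hypothetical.
# predictions[dataloader_idx] mirrors what Lightning passes the callback:
# the list of per-sample tuples produced by the dataloader at that index.
import csv
import os

dataloader_names = ["train", "validation", "test"]  # hypothetical subset keys
predictions = [
    [("img-001.png", 0.91, 1), ("img-002.png", 0.12, 0)],  # train
    [("img-101.png", 0.77, 1)],                            # validation
    [("img-201.png", 0.05, 0)],                            # test
]

for dataloader_idx, name in enumerate(dataloader_names):
    # One predictions.csv per subset, matching the patched writer's layout.
    logfile = os.path.join("out", f"predictions_{name}", "predictions.csv")
    os.makedirs(os.path.dirname(logfile), exist_ok=True)
    with open(logfile, "w") as f:
        writer = csv.DictWriter(
            f, fieldnames=["filename", "likelihood", "ground_truth"]
        )
        writer.writeheader()
        for filename, likelihood, ground_truth in predictions[dataloader_idx]:
            writer.writerow(
                {
                    "filename": filename,
                    "likelihood": likelihood,
                    "ground_truth": ground_truth,
                }
            )

Each subset thus lands in its own predictions_<name>/ directory, rather than every
prediction being written under the first dataloader's name as before.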