From aea0a2e24e86650a802ed120078bb44ea0d1edf5 Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Fri, 21 Jul 2023 18:36:17 +0200
Subject: [PATCH] Predict on all available sets

---
 src/ptbench/data/datamodule.py  | 19 ++++++++++---
 src/ptbench/engine/callbacks.py | 47 +++++++++++++++++----------------
 src/ptbench/scripts/predict.py  |  2 ++
 3 files changed, 42 insertions(+), 26 deletions(-)

diff --git a/src/ptbench/data/datamodule.py b/src/ptbench/data/datamodule.py
index bb2dcdda..af0d513a 100644
--- a/src/ptbench/data/datamodule.py
+++ b/src/ptbench/data/datamodule.py
@@ -686,7 +686,9 @@ class CachingDataModule(lightning.LightningDataModule):
     def _val_dataset_keys(self) -> list[str]:
         """Returns list of validation dataset names."""
         return ["validation"] + [
-            k for k in self.database_split.keys() if k.startswith("monitor-")
+            k
+            for k in self.database_split.subsets.keys()
+            if k.startswith("monitor-")
         ]
 
     def setup(self, stage: str) -> None:
@@ -727,7 +729,8 @@ class CachingDataModule(lightning.LightningDataModule):
             self._setup_dataset("test")
 
         elif stage == "predict":
-            self._setup_dataset("test")
+            for k in self.database_split.subsets.keys():
+                self._setup_dataset(k)
 
     def teardown(self, stage: str) -> None:
         """Unset-up datasets for different tasks on the pipeline.
@@ -814,4 +817,14 @@ class CachingDataModule(lightning.LightningDataModule):
     def predict_dataloader(self) -> dict[str, DataLoader]:
         """Returns the prediction data loader(s)"""
 
-        return self.test_dataloader()
+        return {
+            k: torch.utils.data.DataLoader(
+                self._datasets[k],
+                batch_size=self._chunk_size,
+                shuffle=False,
+                drop_last=self.drop_incomplete_batch,
+                pin_memory=self.pin_memory,
+                **self._dataloader_multiproc,
+            )
+            for k in self._datasets
+        }
diff --git a/src/ptbench/engine/callbacks.py b/src/ptbench/engine/callbacks.py
index 031acaae..c2ce035f 100644
--- a/src/ptbench/engine/callbacks.py
+++ b/src/ptbench/engine/callbacks.py
@@ -402,27 +402,28 @@ class PredictionsWriter(lightning.pytorch.callbacks.BasePredictionWriter):
         predictions: typing.Sequence[typing.Any],
         batch_indices: typing.Sequence[typing.Any] | None,
     ) -> None:
-        dataloader_name = list(trainer.datamodule.predict_dataloader().keys())[
-            0
-        ]
+        dataloader_names = list(trainer.datamodule.predict_dataloader().keys())
 
-        logfile = os.path.join(
-            self.output_dir, f"predictions_{dataloader_name}", "predictions.csv"
-        )
-        os.makedirs(os.path.dirname(logfile), exist_ok=True)
-
-        logger.info(f"Saving predictions in {logfile}.")
-
-        with open(logfile, "w") as l_f:
-            logwriter = csv.DictWriter(l_f, fieldnames=self.logfile_fields)
-            logwriter.writeheader()
-
-            for prediction in predictions:
-                logwriter.writerow(
-                    {
-                        "filename": prediction[0],
-                        "likelihood": prediction[1].numpy(),
-                        "ground_truth": prediction[2].numpy(),
-                    }
-                )
-            l_f.flush()
+        for dataloader_idx, dataloader_name in enumerate(dataloader_names):
+            logfile = os.path.join(
+                self.output_dir,
+                f"predictions_{dataloader_name}",
+                "predictions.csv",
+            )
+            os.makedirs(os.path.dirname(logfile), exist_ok=True)
+
+            logger.info(f"Saving predictions in {logfile}.")
+
+            with open(logfile, "w") as l_f:
+                logwriter = csv.DictWriter(l_f, fieldnames=self.logfile_fields)
+                logwriter.writeheader()
+
+                for prediction in predictions[dataloader_idx]:
+                    logwriter.writerow(
+                        {
+                            "filename": prediction[0],
+                            "likelihood": prediction[1].numpy(),
+                            "ground_truth": prediction[2].numpy(),
+                        }
+                    )
+                l_f.flush()
diff --git a/src/ptbench/scripts/predict.py b/src/ptbench/scripts/predict.py
index 3de96bdd..b5ed123c 100644
--- a/src/ptbench/scripts/predict.py
+++ b/src/ptbench/scripts/predict.py
@@ -102,6 +102,8 @@ def predict(
     datamodule.set_chunk_size(batch_size, 1)
     datamodule.model_transforms = model.model_transforms
 
+    datamodule.prepare_data()
+    datamodule.setup(stage="predict")
 
     logger.info(f"Loading checkpoint from {weight}")
     model = model.load_from_checkpoint(weight, strict=False)
-- 
GitLab