From 48990aea35326cde152f3dcdb41501bd8ddf0081 Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Wed, 16 Aug 2023 23:06:19 +0200
Subject: [PATCH] [engine.predictor] Streamline typing around prediction

---
 src/ptbench/engine/predictor.py | 27 +++++++++----------------
 src/ptbench/models/separate.py  | 36 +++++++++++++++++++++++++++++++--
 src/ptbench/models/typing.py    | 14 ++++++++++++-
 src/ptbench/scripts/predict.py  | 15 ++------------
 4 files changed, 58 insertions(+), 34 deletions(-)

diff --git a/src/ptbench/engine/predictor.py b/src/ptbench/engine/predictor.py
index 033e0829..49e5f5ac 100644
--- a/src/ptbench/engine/predictor.py
+++ b/src/ptbench/engine/predictor.py
@@ -3,11 +3,11 @@
 # SPDX-License-Identifier: GPL-3.0-or-later
 
 import logging
-import pathlib
 
 import lightning.pytorch
 import torch.utils.data
 
+from ..models.typing import Prediction
 from .device import DeviceManager
 
 logger = logging.getLogger(__name__)
@@ -17,8 +17,12 @@ def run(
     model: lightning.pytorch.LightningModule,
     datamodule: lightning.pytorch.LightningDataModule,
     device_manager: DeviceManager,
-    output_folder: pathlib.Path,
-) -> list | list[list] | dict[str, list] | None:
+) -> (
+    list[Prediction]
+    | list[list[Prediction]]
+    | dict[str, list[Prediction]]
+    | None
+):
     """Runs inference on input data, outputs csv files with predictions.
 
     Parameters
@@ -31,8 +35,6 @@ def run(
         An internal device representation, to be used for training and
         validation.  This representation can be converted into a pytorch device
         or a torch lightning accelerator setup.
-    output_folder
-        Directory in which the logs will be saved.
 
 
     Returns
@@ -58,24 +60,13 @@ def run(
         of the types described above.
     """
 
-    from .loggers import CustomTensorboardLogger
-
-    log_dir = "logs"
-    tensorboard_logger = CustomTensorboardLogger(
-        output_folder,
-        log_dir,
-    )
-    # logger.info(
-    #     f"Monitor prediction with `tensorboard serve "
-    #     f"--logdir={output_folder}/{log_dir}/`. "
-    #     f"Then, open a browser on the printed address."
-    # )
+    from lightning.pytorch.loggers.logger import DummyLogger
 
     accelerator, devices = device_manager.lightning_accelerator()
     trainer = lightning.pytorch.Trainer(
         accelerator=accelerator,
         devices=devices,
-        logger=tensorboard_logger,
+        logger=DummyLogger(),
     )
 
     def _flatten(p: list[list]):
diff --git a/src/ptbench/models/separate.py b/src/ptbench/models/separate.py
index 522d5b94..febb0728 100644
--- a/src/ptbench/models/separate.py
+++ b/src/ptbench/models/separate.py
@@ -3,12 +3,32 @@
 # SPDX-License-Identifier: GPL-3.0-or-later
 """Contains the inverse :py:func:`torch.utils.data.default_collate`."""
 
+import typing
+
 import torch
 
 from ..data.typing import Sample
+from .typing import Prediction
+
+
+def _as_predictions(samples: typing.Iterable[Sample]) -> list[Prediction]:
+    """Takes a list of separated batch predictions and transforms it into a
+    list of formal predictions.
 
+    Parameters
+    ----------
+    samples
+        A sequence of samples as returned by :py:func:`separate`.
+
+
+    Returns
+    -------
+        A list of typed predictions that can be saved to disk.
+    """
+    return [(v[1]["name"], v[1]["label"].item(), v[0].item()) for v in samples]
 
-def separate(batch: Sample) -> list[Sample]:
+
+def separate(batch: Sample) -> list[Prediction]:
     """Separates a collated batch reconstituting its samples.
 
     This function implements the inverse of
@@ -19,6 +39,18 @@ def separate(batch: Sample) -> list[Sample]:
     * :class:`torch.Tensor` -> :class:`torch.Tensor` (with a removed outer
       dimension, via :py:func:`torch.flatten`)
     * ``typing.Mapping[K, V[]]`` -> ``[dict[K, V_1], dict[K, V_2], ...]``
+
+
+    Parameters
+    ----------
+    batch
+        A batch, as output by torch model forwarding.
+
+
+    Returns
+    -------
+        A list of predictions, each containing the predicted value and
+        associated metadata for one processed sample.
     """
 
     # as of now, this is really simple - to be made more complex upon need.
@@ -26,4 +58,4 @@ def separate(batch: Sample) -> list[Sample]:
         {key: value[i] for key, value in batch[1].items()}
         for i in range(len(batch[0]))
     ]
-    return list(zip(torch.flatten(batch[0]), metadata))
+    return _as_predictions(zip(torch.flatten(batch[0]), metadata))
diff --git a/src/ptbench/models/typing.py b/src/ptbench/models/typing.py
index 76c02c6d..3294cafd 100644
--- a/src/ptbench/models/typing.py
+++ b/src/ptbench/models/typing.py
@@ -5,5 +5,17 @@
 
 import typing
 
-Checkpoint: typing.TypeAlias = dict[str, typing.Any]
+Checkpoint: typing.TypeAlias = typing.Mapping[str, typing.Any]
 """Definition of a lightning checkpoint."""
+
+
+Prediction: typing.TypeAlias = tuple[
+    str, int | typing.Sequence[int], float | typing.Sequence[float]
+]
+"""Prediction: the sample name, the target, and the predicted value."""
+
+
+PredictionSplit: typing.TypeAlias = typing.Mapping[
+    str, typing.Sequence[Prediction]
+]
+"""A series of predictions for different database splits."""
diff --git a/src/ptbench/scripts/predict.py b/src/ptbench/scripts/predict.py
index c13a4f24..0b54d051 100644
--- a/src/ptbench/scripts/predict.py
+++ b/src/ptbench/scripts/predict.py
@@ -135,7 +135,7 @@ def predict(
     logger.info(f"Loading checkpoint from `{weight}`...")
     model = model.load_from_checkpoint(weight, strict=False)
 
-    predictions = run(model, datamodule, DeviceManager(device), output.parent)
+    predictions = run(model, datamodule, DeviceManager(device))
 
     output.parent.mkdir(parents=True, exist_ok=True)
     if output.exists():
@@ -147,16 +147,5 @@ def predict(
         shutil.copy(output, backup)
 
     with output.open("w") as f:
-        flat_predictions: dict[str, list[list]] = {}
-        # creates a flat representation of predictions that is similar to our
-        # own JSON split files
-        for split_name, split_values in predictions.items():  # type: ignore
-            flat_predictions.setdefault(
-                split_name,
-                [
-                    [v[1]["name"], v[1]["label"].item(), v[0].item()]
-                    for v in split_values
-                ],
-            )
-        json.dump(flat_predictions, f, indent=2)
+        json.dump(predictions, f, indent=2)
     logger.info(f"Predictions saved to `{str(output)}`")
-- 
GitLab