From 50148b8cbb81d73eeaaed46eb212fee3f2517d81 Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Tue, 15 Aug 2023 22:13:42 +0200
Subject: [PATCH] [doc] Documentation fixes

---
 src/ptbench/engine/predictor.py | 67 ++++++++++++++++++++++++---------
 src/ptbench/scripts/click.py    | 30 +++++++++++++++
 src/ptbench/scripts/predict.py  | 18 ++++-----
 src/ptbench/scripts/train.py    | 12 +++---
 4 files changed, 94 insertions(+), 33 deletions(-)
 create mode 100644 src/ptbench/scripts/click.py

diff --git a/src/ptbench/engine/predictor.py b/src/ptbench/engine/predictor.py
index edae044b..92597797 100644
--- a/src/ptbench/engine/predictor.py
+++ b/src/ptbench/engine/predictor.py
@@ -6,6 +6,7 @@ import logging
 import pathlib
 
 import lightning.pytorch
+import torch.utils.data
 
 from .device import DeviceManager
 
@@ -17,7 +18,7 @@ def run(
     datamodule: lightning.pytorch.LightningDataModule,
     device_manager: DeviceManager,
     output_folder: pathlib.Path,
-) -> dict[str, list] | list | list[list] | None:
+) -> list | list[list] | dict[str, list] | None:
     """Runs inference on input data, outputs csv files with predictions.
 
     Parameters
@@ -30,18 +31,31 @@ def run(
         An internal device representation, to be used for training and
         validation.  This representation can be converted into a pytorch device
         or a torch lightning accelerator setup.
-    output_folder : str
+    output_folder
         Directory in which the logs will be saved.
 
 
     Returns
     -------
-    predictions
-        A dictionary containing the predictions for each of the input samples
-        per dataloader.  Keys correspond to the original split names defined at
-        the loader.  If the datamodule's ``predict_dataloader()`` method does
-        not return a dictionary, then its output is directly passed to the
-        trainer ``predict()`` method.
+        Depending on the return type of the datamodule's
+        ``predict_dataloader()`` method:
+
+        * if :py:class:`torch.utils.data.DataLoader`, then returns a
+          :py:class:`list` of predictions.
+        * if :py:class:`list` of :py:class:`torch.utils.data.DataLoader`, then
+          returns a list of lists of predictions, each list corresponding to
+          the iteration over one of the dataloaders.
+        * if :py:class:`dict` of :py:class:`str` to
+          :py:class:`torch.utils.data.DataLoader`, then returns a dictionary
+          mapping split names to lists of predictions.
+        * if ``None``, then returns ``None``.
+
+
+    Raises
+    ------
+    TypeError
+        If the datamodule's ``predict_dataloader()`` method does not return any
+        of the types described above.
     """
 
     from .loggers import CustomTensorboardLogger
@@ -64,16 +78,33 @@ def run(
         logger=tensorboard_logger,
     )
 
+    def _flatten(p: list[list]):
+        return [sample for batch in p for sample in batch]
+
     dataloaders = datamodule.predict_dataloader()
-    if isinstance(dataloaders, dict):
-        retval = {}
+    if isinstance(dataloaders, torch.utils.data.DataLoader):
+        logger.info("Running prediction on a single dataloader...")
+        return _flatten(trainer.predict(model, dataloaders))  # type: ignore
+    elif isinstance(dataloaders, list):
+        retval_list = []
+        for k, dataloader in enumerate(dataloaders):
+            logger.info(f"Running prediction on split `{k}`...")
+            retval_list.append(_flatten(trainer.predict(model, dataloader)))  # type: ignore
+        return retval_list
+    elif isinstance(dataloaders, dict):
+        retval_dict = {}
         for name, dataloader in dataloaders.items():
             logger.info(f"Running prediction on `{name}` split...")
-            predictions = trainer.predict(model, dataloader)
-            retval[name] = [
-                sample for batch in predictions for sample in batch  # type: ignore
-            ]
-        return retval
-
-    # just pass all the loaders to the trainer, let it handle
-    return trainer.predict(model, datamodule)
+            retval_dict[name] = _flatten(trainer.predict(model, dataloader))  # type: ignore
+        return retval_dict
+    elif dataloaders is None:
+        logger.warning("Datamodule did not return any prediction dataloaders!")
+        return None
+
+    # if you get to this point, then the user is returning something that is
+    # not supported - complain!
+    raise TypeError(
+        f"Datamodule returned strangely typed prediction "
+        f"dataloaders: `{type(dataloaders)}` - Please write code "
+        f"to support this use-case."
+    )
diff --git a/src/ptbench/scripts/click.py b/src/ptbench/scripts/click.py
new file mode 100644
index 00000000..39cf96f9
--- /dev/null
+++ b/src/ptbench/scripts/click.py
@@ -0,0 +1,30 @@
+# Copyright © 2022 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+import click
+
+from clapper.click import ConfigCommand as _BaseConfigCommand
+
+
+class ConfigCommand(_BaseConfigCommand):
+    """A click command-class that has the properties of
+    :py:class:`clapper.click.ConfigCommand` and adds verbatim epilog
+    formatting."""
+
+    def format_epilog(
+        self, _: click.core.Context, formatter: click.formatting.HelpFormatter
+    ) -> None:
+        """Formats the command epilog during --help.
+
+        Arguments:
+
+            _: The current parsing context (unused by this method)
+
+            formatter: The formatter to use for printing text
+        """
+
+        if self.epilog:
+            formatter.write_paragraph()
+            for line in self.epilog.split("\n"):
+                formatter.write_text(line)
diff --git a/src/ptbench/scripts/predict.py b/src/ptbench/scripts/predict.py
index 551422fd..0715e423 100644
--- a/src/ptbench/scripts/predict.py
+++ b/src/ptbench/scripts/predict.py
@@ -6,9 +6,11 @@ import pathlib
 
 import click
 
-from clapper.click import ConfigCommand, ResourceOption, verbosity_option
+from clapper.click import ResourceOption, verbosity_option
 from clapper.logging import setup
 
+from .click import ConfigCommand
+
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
 
@@ -17,19 +19,17 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
     cls=ConfigCommand,
     epilog="""Examples:
 
-    1. Runs prediction on an existing datamodule configuration:
+1. Runs prediction on an existing datamodule configuration:
 
-       .. code:: sh
+   .. code:: sh
 
-          \b
-          ptbench predict -vv pasa montgomery --weight=path/to/model.ckpt --output=path/to/predictions.json
+      ptbench predict -vv pasa montgomery --weight=path/to/model.ckpt --output=path/to/predictions.json
 
-    2. Enables multi-processing data loading with 6 processes:
+2. Enables multi-processing data loading with 6 processes:
 
-       .. code:: sh
+   .. code:: sh
 
-          \b
-          ptbench predict -vv pasa montgomery --parallel=6 --weight=path/to/model.ckpt --output=path/to/predictions.json
+      ptbench predict -vv pasa montgomery --parallel=6 --weight=path/to/model.ckpt --output=path/to/predictions.json
 
 """,
 )
diff --git a/src/ptbench/scripts/train.py b/src/ptbench/scripts/train.py
index 8f2f51c9..11ac8e07 100644
--- a/src/ptbench/scripts/train.py
+++ b/src/ptbench/scripts/train.py
@@ -4,9 +4,11 @@
 
 import click
 
-from clapper.click import ConfigCommand, ResourceOption, verbosity_option
+from clapper.click import ResourceOption, verbosity_option
 from clapper.logging import setup
 
+from .click import ConfigCommand
+
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
 
@@ -15,13 +17,11 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
     cls=ConfigCommand,
     epilog="""Examples:
 
-\b
-    1. Trains Pasa's model with Montgomery dataset, on a GPU (``cuda:0``):
-
-       .. code:: sh
+1. Trains Pasa's model with Montgomery dataset, on a GPU (``cuda:0``):
 
-          ptbench train -vv pasa montgomery --batch-size=4 --device="cuda:0"
+   .. code:: sh
 
+      ptbench train -vv pasa montgomery --batch-size=4 --device="cuda:0"
 """,
 )
 @click.option(
-- 
GitLab