From 50148b8cbb81d73eeaaed46eb212fee3f2517d81 Mon Sep 17 00:00:00 2001 From: Andre Anjos <andre.dos.anjos@gmail.com> Date: Tue, 15 Aug 2023 22:13:42 +0200 Subject: [PATCH] [doc] Documentation fixes --- src/ptbench/engine/predictor.py | 67 ++++++++++++++++++++++++--------- src/ptbench/scripts/click.py | 30 +++++++++++++++ src/ptbench/scripts/predict.py | 18 ++++----- src/ptbench/scripts/train.py | 12 +++--- 4 files changed, 94 insertions(+), 33 deletions(-) create mode 100644 src/ptbench/scripts/click.py diff --git a/src/ptbench/engine/predictor.py b/src/ptbench/engine/predictor.py index edae044b..92597797 100644 --- a/src/ptbench/engine/predictor.py +++ b/src/ptbench/engine/predictor.py @@ -6,6 +6,7 @@ import logging import pathlib import lightning.pytorch +import torch.utils.data from .device import DeviceManager @@ -17,7 +18,7 @@ def run( datamodule: lightning.pytorch.LightningDataModule, device_manager: DeviceManager, output_folder: pathlib.Path, -) -> dict[str, list] | list | list[list] | None: +) -> list | list[list] | dict[str, list] | None: """Runs inference on input data, outputs csv files with predictions. Parameters @@ -30,18 +31,31 @@ def run( An internal device representation, to be used for training and validation. This representation can be converted into a pytorch device or a torch lightning accelerator setup. - output_folder : str + output_folder Directory in which the logs will be saved. Returns ------- - predictions - A dictionary containing the predictions for each of the input samples - per dataloader. Keys correspond to the original split names defined at - the loader. If the datamodule's ``predict_dataloader()`` method does - not return a dictionary, then its output is directly passed to the - trainer ``predict()`` method. + Depending on the return type of the datamodule's + ``predict_dataloader()`` method: + + * if :py:class:`torch.utils.data.DataLoader`, then returns a + :py:class:`list` of predictions + * if :py:class:`list` of :py:class:`torch.utils.data.DataLoader`, then + returns a list of lists of predictions, each list corresponding to + the iteration over one of the dataloaders. + * if :py:class:`dict` of :py:class:`str` to + :py:class:`torch.utils.data.DataLoader`, then returns a dictionary + mapping names to lists of predictions + * if ``None``, then returns ``None`` + + + Raises + ------ + TypeError + If the datamodule's ``predict_dataloader()`` method does not return any + of the types described above. """ from .loggers import CustomTensorboardLogger @@ -64,16 +78,33 @@ def run( logger=tensorboard_logger, ) + def _flatten(p: list[list]): + return [sample for batch in p for sample in batch] + dataloaders = datamodule.predict_dataloader() - if isinstance(dataloaders, dict): - retval = {} + if isinstance(dataloaders, torch.utils.data.DataLoader): + logger.info("Running prediction on a single dataloader...") + return _flatten(trainer.predict(model, dataloaders)) # type: ignore + elif isinstance(dataloaders, list): + retval_list = [] + for k, dataloader in enumerate(dataloaders): + logger.info(f"Running prediction on split `{k}`...") + retval_list.append(_flatten(trainer.predict(model, dataloader))) # type: ignore + return retval_list + elif isinstance(dataloaders, dict): + retval_dict = {} for name, dataloader in dataloaders.items(): logger.info(f"Running prediction on `{name}` split...") - predictions = trainer.predict(model, dataloader) - retval[name] = [ - sample for batch in predictions for sample in batch # type: ignore - ] - return retval - - # just pass all the loaders to the trainer, let it handle - return trainer.predict(model, datamodule) + retval_dict[name] = _flatten(trainer.predict(model, dataloader)) # type: ignore + return retval_dict + elif dataloaders is None: + logger.warning("Datamodule did not return any prediction dataloaders!") + return None + + # if you get to this point, then the user is returning something that is + # not supported - complain! + raise TypeError( + f"Datamodule returned strangely typed prediction " + f"dataloaders: `{type(dataloaders)}` - Please write code " + f"to support this use-case." + ) diff --git a/src/ptbench/scripts/click.py b/src/ptbench/scripts/click.py new file mode 100644 index 00000000..39cf96f9 --- /dev/null +++ b/src/ptbench/scripts/click.py @@ -0,0 +1,30 @@ +# Copyright © 2022 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import click + +from clapper.click import ConfigCommand as _BaseConfigCommand + + +class ConfigCommand(_BaseConfigCommand): + """A click command-class that has the properties of + :py:class:`clapper.click.ConfigCommand` and adds verbatim epilog + formatting.""" + + def format_epilog( + self, _: click.core.Context, formatter: click.formatting.HelpFormatter + ) -> None: + """Formats the command epilog during --help. + + Arguments: + + _: The current parsing context + + formatter: The formatter to use for printing text + """ + + if self.epilog: + formatter.write_paragraph() + for line in self.epilog.split("\n"): + formatter.write_text(line) diff --git a/src/ptbench/scripts/predict.py b/src/ptbench/scripts/predict.py index 551422fd..0715e423 100644 --- a/src/ptbench/scripts/predict.py +++ b/src/ptbench/scripts/predict.py @@ -6,9 +6,11 @@ import pathlib import click -from clapper.click import ConfigCommand, ResourceOption, verbosity_option +from clapper.click import ResourceOption, verbosity_option from clapper.logging import setup +from .click import ConfigCommand + logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") @@ -17,19 +19,17 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") cls=ConfigCommand, epilog="""Examples: - 1. Runs prediction on an existing datamodule configuration: +1. Runs prediction on an existing datamodule configuration: - .. code:: sh + .. code:: sh - \b - ptbench predict -vv pasa montgomery --weight=path/to/model.ckpt --output=path/to/predictions.json + ptbench predict -vv pasa montgomery --weight=path/to/model.ckpt --output=path/to/predictions.json - 2. Enables multi-processing data loading with 6 processes: +2. Enables multi-processing data loading with 6 processes: - .. code:: sh + .. code:: sh - \b - ptbench predict -vv pasa montgomery --parallel=6 --weight=path/to/model.ckpt --output=path/to/predictions.json + ptbench predict -vv pasa montgomery --parallel=6 --weight=path/to/model.ckpt --output=path/to/predictions.json """, ) diff --git a/src/ptbench/scripts/train.py b/src/ptbench/scripts/train.py index 8f2f51c9..11ac8e07 100644 --- a/src/ptbench/scripts/train.py +++ b/src/ptbench/scripts/train.py @@ -4,9 +4,11 @@ import click -from clapper.click import ConfigCommand, ResourceOption, verbosity_option +from clapper.click import ResourceOption, verbosity_option from clapper.logging import setup +from .click import ConfigCommand + logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") @@ -15,13 +17,11 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") cls=ConfigCommand, epilog="""Examples: -\b - 1. Trains Pasa's model with Montgomery dataset, on a GPU (``cuda:0``): - - .. code:: sh +1. Trains Pasa's model with Montgomery dataset, on a GPU (``cuda:0``): - ptbench train -vv pasa montgomery --batch-size=4 --device="cuda:0" + .. code:: sh + ptbench train -vv pasa montgomery --batch-size=4 --device="cuda:0" """, ) @click.option( -- GitLab