diff --git a/src/ptbench/engine/predictor.py b/src/ptbench/engine/predictor.py index edae044b0a0df037751b8c86fa42f68ef9526fbc..92597797815bb51e159a821bdefa2439c9791de1 100644 --- a/src/ptbench/engine/predictor.py +++ b/src/ptbench/engine/predictor.py @@ -6,6 +6,7 @@ import logging import pathlib import lightning.pytorch +import torch.utils.data from .device import DeviceManager @@ -17,7 +18,7 @@ def run( datamodule: lightning.pytorch.LightningDataModule, device_manager: DeviceManager, output_folder: pathlib.Path, -) -> dict[str, list] | list | list[list] | None: +) -> list | list[list] | dict[str, list] | None: """Runs inference on input data, outputs csv files with predictions. Parameters @@ -30,18 +31,31 @@ def run( An internal device representation, to be used for training and validation. This representation can be converted into a pytorch device or a torch lightning accelerator setup. - output_folder : str + output_folder Directory in which the logs will be saved. Returns ------- - predictions - A dictionary containing the predictions for each of the input samples - per dataloader. Keys correspond to the original split names defined at - the loader. If the datamodule's ``predict_dataloader()`` method does - not return a dictionary, then its output is directly passed to the - trainer ``predict()`` method. + Depending on the return type of the datamodule's + ``predict_dataloader()`` method: + + * if :py:class:`torch.utils.data.DataLoader`, then returns a + :py:class:`list` of predictions + * if :py:class:`list` of :py:class:`torch.utils.data.DataLoader`, then + returns a list of lists of predictions, each list corresponding to + the iteration over one of the dataloaders. 
+ * if :py:class:`dict` of :py:class:`str` to + :py:class:`torch.utils.data.DataLoader`, then returns a dictionary + mapping names to lists of predictions + * if ``None``, then returns ``None`` + + + Raises + ------ + TypeError + If the datamodule's ``predict_dataloader()`` method does not return any + of the types described above. """ from .loggers import CustomTensorboardLogger @@ -64,16 +78,33 @@ def run( logger=tensorboard_logger, ) + def _flatten(p: list[list]): + return [sample for batch in p for sample in batch] + dataloaders = datamodule.predict_dataloader() - if isinstance(dataloaders, dict): - retval = {} + if isinstance(dataloaders, torch.utils.data.DataLoader): + logger.info("Running prediction on a single dataloader...") + return _flatten(trainer.predict(model, dataloaders)) # type: ignore + elif isinstance(dataloaders, list): + retval_list = [] + for k, dataloader in enumerate(dataloaders): + logger.info(f"Running prediction on split `{k}`...") + retval_list.append(_flatten(trainer.predict(model, dataloader))) # type: ignore + return retval_list + elif isinstance(dataloaders, dict): + retval_dict = {} for name, dataloader in dataloaders.items(): logger.info(f"Running prediction on `{name}` split...") - predictions = trainer.predict(model, dataloader) - retval[name] = [ - sample for batch in predictions for sample in batch # type: ignore - ] - return retval - - # just pass all the loaders to the trainer, let it handle - return trainer.predict(model, datamodule) + retval_dict[name] = _flatten(trainer.predict(model, dataloader)) # type: ignore + return retval_dict + elif dataloaders is None: + logger.warning("Datamodule did not return any prediction dataloaders!") + return None + + # if you get to this point, then the user is returning something that is + # not supported - complain! + raise TypeError( + f"Datamodule returned strangely typed prediction " + f"dataloaders: `{type(dataloaders)}` - Please write code " + f"to support this use-case." 
class ConfigCommand(_BaseConfigCommand):
    """Click command class with line-by-line epilog output.

    Derives from :py:class:`clapper.click.ConfigCommand` and overrides
    :py:meth:`format_epilog` so that the epilog is passed to the help
    formatter one line at a time.
    """

    def format_epilog(
        self, _: click.core.Context, formatter: click.formatting.HelpFormatter
    ) -> None:
        """Write the command epilog to the help formatter.

        Nothing is written when the command has no epilog (or an empty one).

        Arguments:

            _: The current parsing context (not used by this method)

            formatter: The formatter to use for printing text
        """

        # Guard clause: no epilog, nothing to print.
        if not self.epilog:
            return

        # Open a new paragraph in the formatter before the epilog text.
        formatter.write_paragraph()

        # One ``write_text()`` call per epilog line, in source order.
        epilog_lines = self.epilog.split("\n")
        for epilog_line in epilog_lines:
            formatter.write_text(epilog_line)
code:: sh - \b - ptbench predict -vv pasa montgomery --weight=path/to/model.ckpt --output=path/to/predictions.json + ptbench predict -vv pasa montgomery --weight=path/to/model.ckpt --output=path/to/predictions.json - 2. Enables multi-processing data loading with 6 processes: +2. Enables multi-processing data loading with 6 processes: - .. code:: sh + .. code:: sh - \b - ptbench predict -vv pasa montgomery --parallel=6 --weight=path/to/model.ckpt --output=path/to/predictions.json + ptbench predict -vv pasa montgomery --parallel=6 --weight=path/to/model.ckpt --output=path/to/predictions.json """, ) diff --git a/src/ptbench/scripts/train.py b/src/ptbench/scripts/train.py index 8f2f51c923c176a2929a9c78baf1049b43ed53ee..11ac8e07b546a80e407ac79cc42184c37187a3e2 100644 --- a/src/ptbench/scripts/train.py +++ b/src/ptbench/scripts/train.py @@ -4,9 +4,11 @@ import click -from clapper.click import ConfigCommand, ResourceOption, verbosity_option +from clapper.click import ResourceOption, verbosity_option from clapper.logging import setup +from .click import ConfigCommand + logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") @@ -15,13 +17,11 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") cls=ConfigCommand, epilog="""Examples: -\b - 1. Trains Pasa's model with Montgomery dataset, on a GPU (``cuda:0``): - - .. code:: sh +1. Trains Pasa's model with Montgomery dataset, on a GPU (``cuda:0``): - ptbench train -vv pasa montgomery --batch-size=4 --device="cuda:0" + .. code:: sh + ptbench train -vv pasa montgomery --batch-size=4 --device="cuda:0" """, ) @click.option(