Skip to content
Snippets Groups Projects
Commit 50148b8c authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[doc] Documentation fixes

parent 198c54cb
No related branches found
No related tags found
1 merge request!6Making use of LightningDataModule and simplification of data loading
......@@ -6,6 +6,7 @@ import logging
import pathlib
import lightning.pytorch
import torch.utils.data
from .device import DeviceManager
......@@ -17,7 +18,7 @@ def run(
datamodule: lightning.pytorch.LightningDataModule,
device_manager: DeviceManager,
output_folder: pathlib.Path,
) -> dict[str, list] | list | list[list] | None:
) -> list | list[list] | dict[str, list] | None:
"""Runs inference on input data, outputs csv files with predictions.
Parameters
......@@ -30,18 +31,31 @@ def run(
An internal device representation, to be used for training and
validation. This representation can be converted into a pytorch device
or a torch lightning accelerator setup.
output_folder : str
output_folder
Directory in which the logs will be saved.
Returns
-------
predictions
A dictionary containing the predictions for each of the input samples
per dataloader. Keys correspond to the original split names defined at
the loader. If the datamodule's ``predict_dataloader()`` method does
not return a dictionary, then its output is directly passed to the
trainer ``predict()`` method.
Depending on the return type of the datamodule's
``predict_dataloader()`` method:
* if :py:class:`torch.utils.data.DataLoader`, then returns a
:py:class:`list` of predictions
* if :py:class:`list` of :py:class:`torch.utils.data.DataLoader`, then
returns a list of lists of predictions, each list corresponding to
the iteration over one of the dataloaders.
* if :py:class:`dict` of :py:class:`str` to
:py:class:`torch.utils.data.DataLoader`, then returns a dictionary
mapping names to lists of predictions
* if ``None``, then returns ``None``
Raises
------
TypeError
If the datamodule's ``predict_dataloader()`` method does not return any
of the types described above.
"""
from .loggers import CustomTensorboardLogger
......@@ -64,16 +78,33 @@ def run(
logger=tensorboard_logger,
)
def _flatten(p: list[list]):
return [sample for batch in p for sample in batch]
dataloaders = datamodule.predict_dataloader()
if isinstance(dataloaders, dict):
retval = {}
if isinstance(dataloaders, torch.utils.data.DataLoader):
logger.info("Running prediction on a single dataloader...")
return _flatten(trainer.predict(model, dataloaders)) # type: ignore
elif isinstance(dataloaders, list):
retval_list = []
for k, dataloader in enumerate(dataloaders):
logger.info(f"Running prediction on split `{k}`...")
retval_list.append(_flatten(trainer.predict(model, dataloader))) # type: ignore
return retval_list
elif isinstance(dataloaders, dict):
retval_dict = {}
for name, dataloader in dataloaders.items():
logger.info(f"Running prediction on `{name}` split...")
predictions = trainer.predict(model, dataloader)
retval[name] = [
sample for batch in predictions for sample in batch # type: ignore
]
return retval
# just pass all the loaders to the trainer, let it handle
return trainer.predict(model, datamodule)
retval_dict[name] = _flatten(trainer.predict(model, dataloader)) # type: ignore
return retval_dict
elif dataloaders is None:
logger.warning("Datamodule did not return any prediction dataloaders!")
return None
# if you get to this point, then the user is returning something that is
# not supported - complain!
raise TypeError(
f"Datamodule returned strangely typed prediction "
f"dataloaders: `{type(dataloaders)}` - Please write code "
f"to support this use-case."
)
# Copyright © 2022 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
import click
from clapper.click import ConfigCommand as _BaseConfigCommand
class ConfigCommand(_BaseConfigCommand):
"""A click command-class that has the properties of
:py:class:`clapper.click.ConfigCommand` and adds verbatim epilog
formatting."""
def format_epilog(
self, _: click.core.Context, formatter: click.formatting.HelpFormatter
) -> None:
"""Formats the command epilog during --help.
Arguments:
_: The current parsing context
formatter: The formatter to use for printing text
"""
if self.epilog:
formatter.write_paragraph()
for line in self.epilog.split("\n"):
formatter.write_text(line)
......@@ -6,9 +6,11 @@ import pathlib
import click
from clapper.click import ConfigCommand, ResourceOption, verbosity_option
from clapper.click import ResourceOption, verbosity_option
from clapper.logging import setup
from .click import ConfigCommand
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
......@@ -17,19 +19,17 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
cls=ConfigCommand,
epilog="""Examples:
1. Runs prediction on an existing datamodule configuration:
1. Runs prediction on an existing datamodule configuration:
.. code:: sh
.. code:: sh
\b
ptbench predict -vv pasa montgomery --weight=path/to/model.ckpt --output=path/to/predictions.json
ptbench predict -vv pasa montgomery --weight=path/to/model.ckpt --output=path/to/predictions.json
2. Enables multi-processing data loading with 6 processes:
2. Enables multi-processing data loading with 6 processes:
.. code:: sh
.. code:: sh
\b
ptbench predict -vv pasa montgomery --parallel=6 --weight=path/to/model.ckpt --output=path/to/predictions.json
ptbench predict -vv pasa montgomery --parallel=6 --weight=path/to/model.ckpt --output=path/to/predictions.json
""",
)
......
......@@ -4,9 +4,11 @@
import click
from clapper.click import ConfigCommand, ResourceOption, verbosity_option
from clapper.click import ResourceOption, verbosity_option
from clapper.logging import setup
from .click import ConfigCommand
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
......@@ -15,13 +17,11 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
cls=ConfigCommand,
epilog="""Examples:
\b
1. Trains Pasa's model with Montgomery dataset, on a GPU (``cuda:0``):
.. code:: sh
1. Trains Pasa's model with Montgomery dataset, on a GPU (``cuda:0``):
ptbench train -vv pasa montgomery --batch-size=4 --device="cuda:0"
.. code:: sh
ptbench train -vv pasa montgomery --batch-size=4 --device="cuda:0"
""",
)
@click.option(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment