Commit 48990aea authored by André Anjos

[engine.predictor] Streamline typing around prediction

parent b0a1d2eb
1 merge request: !6 Making use of LightningDataModule and simplification of data loading
Pipeline #77116 failed
@@ -3,11 +3,11 @@
 # SPDX-License-Identifier: GPL-3.0-or-later

 import logging
-import pathlib

 import lightning.pytorch
 import torch.utils.data

+from ..models.typing import Prediction
 from .device import DeviceManager

 logger = logging.getLogger(__name__)
@@ -17,8 +17,12 @@ def run(
     model: lightning.pytorch.LightningModule,
     datamodule: lightning.pytorch.LightningDataModule,
     device_manager: DeviceManager,
-    output_folder: pathlib.Path,
-) -> list | list[list] | dict[str, list] | None:
+) -> (
+    list[Prediction]
+    | list[list[Prediction]]
+    | dict[str, list[Prediction]]
+    | None
+):
     """Runs inference on input data, outputs csv files with predictions.

     Parameters
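For context (not part of the commit): a minimal sketch of how a caller could narrow the new union return type of run(). The Prediction alias is copied from the typing diff further below; count_predictions is a hypothetical helper invented for illustration.

import typing

Prediction: typing.TypeAlias = tuple[
    str, int | typing.Sequence[int], float | typing.Sequence[float]
]

def count_predictions(
    result: (
        list[Prediction]
        | list[list[Prediction]]
        | dict[str, list[Prediction]]
        | None
    ),
) -> int:
    # None: the trainer produced no output at all.
    if result is None:
        return 0
    # Mapping: one list of predictions per datamodule split.
    if isinstance(result, dict):
        return sum(len(v) for v in result.values())
    # List of lists: one list per dataloader.
    if result and isinstance(result[0], list):
        return sum(len(inner) for inner in result)
    # Flat list: a single dataloader.
    return len(result)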
@@ -31,8 +35,6 @@ def run(
         An internal device representation, to be used for training and
         validation. This representation can be converted into a pytorch device
         or a torch lightning accelerator setup.
-    output_folder
-        Directory in which the logs will be saved.

     Returns
@@ -58,24 +60,13 @@ def run(
     of the types described above.
     """

-    from .loggers import CustomTensorboardLogger
-
-    log_dir = "logs"
-    tensorboard_logger = CustomTensorboardLogger(
-        output_folder,
-        log_dir,
-    )
-    # logger.info(
-    #     f"Monitor prediction with `tensorboard serve "
-    #     f"--logdir={output_folder}/{log_dir}/`. "
-    #     f"Then, open a browser on the printed address."
-    # )
+    from lightning.pytorch.loggers.logger import DummyLogger

     accelerator, devices = device_manager.lightning_accelerator()

     trainer = lightning.pytorch.Trainer(
         accelerator=accelerator,
         devices=devices,
-        logger=tensorboard_logger,
+        logger=DummyLogger(),
     )

     def _flatten(p: list[list]):
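A self-contained sketch of the new logger wiring (the "cpu"/1 accelerator values stand in for what device_manager.lightning_accelerator() returns). DummyLogger is Lightning's no-op logger, so prediction runs no longer write tensorboard files, which is why the output_folder argument could be dropped from run():

import lightning.pytorch
from lightning.pytorch.loggers.logger import DummyLogger

# Every logging call becomes a no-op: no tensorboard artifacts are left
# behind by prediction runs.
trainer = lightning.pytorch.Trainer(
    accelerator="cpu",  # stand-in for device_manager.lightning_accelerator()
    devices=1,
    logger=DummyLogger(),
)
# trainer.predict(model, datamodule=datamodule) would then run inference
# without producing any log files.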
@@ -3,12 +3,32 @@
 # SPDX-License-Identifier: GPL-3.0-or-later

 """Contains the inverse :py:func:`torch.utils.data.default_collate`."""

+import typing
+
 import torch

 from ..data.typing import Sample
+from .typing import Prediction
+
+
+def _as_predictions(samples: typing.Iterable[Sample]) -> list[Prediction]:
+    """Takes a list of separated batch predictions and transforms it into a
+    list of formal predictions.
+
+    Parameters
+    ----------
+    samples
+        A sequence of samples as returned by :py:func:`separate`.
+
+    Returns
+    -------
+    A list of typed predictions that can be saved to disk.
+    """
+
+    return [(v[1]["name"], v[1]["label"].item(), v[0].item()) for v in samples]


-def separate(batch: Sample) -> list[Sample]:
+def separate(batch: Sample) -> list[Prediction]:
     """Separates a collated batch reconstituting its samples.

     This function implements the inverse of
@@ -19,6 +39,18 @@ def separate(batch: Sample) -> list[Sample]:
     * :class:`torch.Tensor` -> :class:`torch.Tensor` (with a removed outer
       dimension, via :py:func:`torch.flatten`)
     * ``typing.Mapping[K, V[]]`` -> ``[dict[K, V_1], dict[K, V_2], ...]``
+
+    Parameters
+    ----------
+    batch
+        A batch, as output by torch model forwarding
+
+    Returns
+    -------
+    A list of predictions, each containing the prediction value and
+    associated metadata for one processed sample.
     """

     # as of now, this is really simple - to be made more complex upon need.
@@ -26,4 +58,4 @@ def separate(batch: Sample) -> list[Sample]:
         {key: value[i] for key, value in batch[1].items()}
         for i in range(len(batch[0]))
     ]
-    return list(zip(torch.flatten(batch[0]), metadata))
+    return _as_predictions(zip(torch.flatten(batch[0]), metadata))
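To see the new flow end to end, a small self-contained sketch (the sample names and values are made up) reproducing what separate() plus _as_predictions() do to a collated batch:

import torch

# A collated batch as default_collate would produce it: stacked model
# outputs plus per-key collated metadata.
batch = (
    torch.tensor([[0.9], [0.2]]),  # model outputs, shape (batch, 1)
    {
        "name": ["sample-1.png", "sample-2.png"],
        "label": torch.tensor([1, 0]),
    },
)

# separate(): rebuild one metadata dict per sample...
metadata = [
    {key: value[i] for key, value in batch[1].items()}
    for i in range(len(batch[0]))
]
# ...then _as_predictions(): pair each flattened score with its metadata
# and emit (name, target, prediction) tuples.
predictions = [
    (v[1]["name"], v[1]["label"].item(), v[0].item())
    for v in zip(torch.flatten(batch[0]), metadata)
]
print(predictions)
# [('sample-1.png', 1, 0.899...), ('sample-2.png', 0, 0.200...)]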
@@ -5,5 +5,17 @@
 import typing

-Checkpoint: typing.TypeAlias = dict[str, typing.Any]
+Checkpoint: typing.TypeAlias = typing.Mapping[str, typing.Any]
 """Definition of a lightning checkpoint."""
+
+Prediction: typing.TypeAlias = tuple[
+    str, int | typing.Sequence[int], float | typing.Sequence[float]
+]
+"""Prediction: the sample name, the target, and the predicted value."""
+
+PredictionSplit: typing.TypeAlias = typing.Mapping[
+    str, typing.Sequence[Prediction]
+]
+"""A series of predictions for different database splits."""
@@ -135,7 +135,7 @@ def predict(
     logger.info(f"Loading checkpoint from `{weight}`...")
     model = model.load_from_checkpoint(weight, strict=False)

-    predictions = run(model, datamodule, DeviceManager(device), output.parent)
+    predictions = run(model, datamodule, DeviceManager(device))

     output.parent.mkdir(parents=True, exist_ok=True)
     if output.exists():
@@ -147,16 +147,5 @@ def predict(
         shutil.copy(output, backup)

     with output.open("w") as f:
-        flat_predictions: dict[str, list[list]] = {}
-        # creates a flat representation of predictions that is similar to our
-        # own JSON split files
-        for split_name, split_values in predictions.items():  # type: ignore
-            flat_predictions.setdefault(
-                split_name,
-                [
-                    [v[1]["name"], v[1]["label"].item(), v[0].item()]
-                    for v in split_values
-                ],
-            )
-        json.dump(flat_predictions, f, indent=2)
+        json.dump(predictions, f, indent=2)

     logger.info(f"Predictions saved to `{str(output)}`")