diff --git a/src/ptbench/engine/predictor.py b/src/ptbench/engine/predictor.py
index 033e0829e2b060ec0351feb7f63adb9e0f690c35..49e5f5acdbd1866e3cd10b323720d2eda3f2d0c2 100644
--- a/src/ptbench/engine/predictor.py
+++ b/src/ptbench/engine/predictor.py
@@ -3,11 +3,11 @@
 # SPDX-License-Identifier: GPL-3.0-or-later
 
 import logging
-import pathlib
 
 import lightning.pytorch
 import torch.utils.data
 
+from ..models.typing import Prediction
 from .device import DeviceManager
 
 logger = logging.getLogger(__name__)
@@ -17,8 +17,12 @@ def run(
     model: lightning.pytorch.LightningModule,
     datamodule: lightning.pytorch.LightningDataModule,
     device_manager: DeviceManager,
-    output_folder: pathlib.Path,
-) -> list | list[list] | dict[str, list] | None:
+) -> (
+    list[Prediction]
+    | list[list[Prediction]]
+    | dict[str, list[Prediction]]
+    | None
+):
     """Runs inference on input data, outputs csv files with predictions.
 
     Parameters
@@ -31,8 +35,6 @@
         An internal device representation, to be used for training and
         validation. This representation can be converted into a pytorch device
         or a torch lightning accelerator setup.
-    output_folder
-        Directory in which the logs will be saved.
 
 
     Returns
@@ -58,24 +60,13 @@
     of the types described above.
     """
 
-    from .loggers import CustomTensorboardLogger
-
-    log_dir = "logs"
-    tensorboard_logger = CustomTensorboardLogger(
-        output_folder,
-        log_dir,
-    )
-    # logger.info(
-    #     f"Monitor prediction with `tensorboard serve "
-    #     f"--logdir={output_folder}/{log_dir}/`. "
-    #     f"Then, open a browser on the printed address."
-    # )
+    from lightning.pytorch.loggers.logger import DummyLogger
 
     accelerator, devices = device_manager.lightning_accelerator()
     trainer = lightning.pytorch.Trainer(
         accelerator=accelerator,
         devices=devices,
-        logger=tensorboard_logger,
+        logger=DummyLogger(),
     )
 
     def _flatten(p: list[list]):
diff --git a/src/ptbench/models/separate.py b/src/ptbench/models/separate.py
index 522d5b94d2313a34c23cb6322c1d986dc3631217..febb0728e3db919b119ca76402dccc516ba41306 100644
--- a/src/ptbench/models/separate.py
+++ b/src/ptbench/models/separate.py
@@ -3,12 +3,32 @@
 # SPDX-License-Identifier: GPL-3.0-or-later
 
 """Contains the inverse :py:func:`torch.utils.data.default_collate`."""
 
+import typing
+
 import torch
 
 from ..data.typing import Sample
+from .typing import Prediction
+
+
+def _as_predictions(samples: typing.Iterable[Sample]) -> list[Prediction]:
+    """Takes a list of separated batch predictions and transforms it into
+    a list of formal predictions.
+    Parameters
+    ----------
+    samples
+        A sequence of samples as returned by :py:func:`separate`.
+
+
+    Returns
+    -------
+    A list of typed predictions that can be saved to disk.
+    """
+    return [(v[1]["name"], v[1]["label"].item(), v[0].item()) for v in samples]
 
-def separate(batch: Sample) -> list[Sample]:
+
+def separate(batch: Sample) -> list[Prediction]:
     """Separates a collated batch reconstituting its samples.
 
     This function implements the inverse of
@@ -19,6 +39,18 @@ def separate(batch: Sample) -> list[Sample]:
     * :class:`torch.Tensor` -> :class:`torch.Tensor` (with a removed outer
       dimension, via :py:func:`torch.flatten`)
     * ``typing.Mapping[K, V[]]`` -> ``[dict[K, V_1], dict[K, V_2], ...]``
+
+
+    Parameters
+    ----------
+    batch
+        A batch, as output by torch model forwarding.
+
+
+    Returns
+    -------
+    A list of predictions that contain the predicted value and associated
+    metadata for each processed sample.
     """
 
     # as of now, this is really simple - to be made more complex upon need.
@@ -26,4 +58,4 @@ def separate(batch: Sample) -> list[Sample]:
         {key: value[i] for key, value in batch[1].items()}
         for i in range(len(batch[0]))
     ]
-    return list(zip(torch.flatten(batch[0]), metadata))
+    return _as_predictions(zip(torch.flatten(batch[0]), metadata))
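Note (reviewer sketch, not part of the patch): after this change, `separate()` returns a list of `Prediction` tuples (sample name, target, predicted value) instead of a list of `Sample`s. The round-trip below illustrates that behaviour with made-up sample names and labels; it only assumes batches collated by `torch.utils.data.default_collate`, as elsewhere in ptbench:

    import torch
    from torch.utils.data import default_collate

    # two hypothetical (tensor, metadata) samples, as a model's predict step sees them
    samples = [
        (torch.tensor([0.9]), {"name": "img-001.png", "label": torch.tensor(1)}),
        (torch.tensor([0.2]), {"name": "img-002.png", "label": torch.tensor(0)}),
    ]
    batch = default_collate(samples)  # (tensor of shape [2, 1], collated metadata dict)

    # inverse of collation, mirroring separate() + _as_predictions() above
    metadata = [
        {key: value[i] for key, value in batch[1].items()}
        for i in range(len(batch[0]))
    ]
    predictions = [
        (m["name"], m["label"].item(), p.item())
        for p, m in zip(torch.flatten(batch[0]), metadata)
    ]
    # predictions == [("img-001.png", 1, ~0.9), ("img-002.png", 0, ~0.2)]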
diff --git a/src/ptbench/models/typing.py b/src/ptbench/models/typing.py
index 76c02c6d81b602472e3126a6698d2028b83dba53..3294cafdddc0d2415b787f8c1eb494ec31bdfdd6 100644
--- a/src/ptbench/models/typing.py
+++ b/src/ptbench/models/typing.py
@@ -5,5 +5,17 @@
 
 import typing
 
-Checkpoint: typing.TypeAlias = dict[str, typing.Any]
+Checkpoint: typing.TypeAlias = typing.Mapping[str, typing.Any]
 """Definition of a lightning checkpoint."""
+
+
+Prediction: typing.TypeAlias = tuple[
+    str, int | typing.Sequence[int], float | typing.Sequence[float]
+]
+"""Prediction: the sample name, the target, and the predicted value."""
+
+
+PredictionSplit: typing.TypeAlias = typing.Mapping[
+    str, typing.Sequence[Prediction]
+]
+"""A series of predictions for different database splits."""
diff --git a/src/ptbench/scripts/predict.py b/src/ptbench/scripts/predict.py
index c13a4f242b783335f85072d88a6e951fb988a398..0b54d05146cbbb6fceffd5e58f6928e0580cb1d5 100644
--- a/src/ptbench/scripts/predict.py
+++ b/src/ptbench/scripts/predict.py
@@ -135,7 +135,7 @@ def predict(
     logger.info(f"Loading checkpoint from `{weight}`...")
     model = model.load_from_checkpoint(weight, strict=False)
 
-    predictions = run(model, datamodule, DeviceManager(device), output.parent)
+    predictions = run(model, datamodule, DeviceManager(device))
 
     output.parent.mkdir(parents=True, exist_ok=True)
     if output.exists():
@@ -147,16 +147,5 @@ def predict(
         shutil.copy(output, backup)
 
     with output.open("w") as f:
-        flat_predictions: dict[str, list[list]] = {}
-        # creates a flat representation of predictions that is similar to our
-        # own JSON split files
-        for split_name, split_values in predictions.items():  # type: ignore
-            flat_predictions.setdefault(
-                split_name,
-                [
-                    [v[1]["name"], v[1]["label"].item(), v[0].item()]
-                    for v in split_values
-                ],
-            )
-        json.dump(flat_predictions, f, indent=2)
+        json.dump(predictions, f, indent=2)
     logger.info(f"Predictions saved to `{str(output)}`")
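Note (reviewer sketch, not part of the patch): with the flattening loop removed from `predict.py`, the JSON file now carries the `run()` output directly — in the usual case a `PredictionSplit`-style mapping from split names to rows, where each `Prediction` tuple serializes as a JSON array `[name, target, prediction]`. A minimal consumer, assuming that dict-shaped output, binary targets, and a hypothetical `predictions.json` path:

    import json
    import pathlib

    with pathlib.Path("predictions.json").open() as f:
        split_predictions = json.load(f)  # e.g. {"test": [["img-001.png", 1, 0.87], ...]}

    for split_name, rows in split_predictions.items():
        # each row mirrors the Prediction alias: [sample name, target, score]
        hits = sum(1 for _, target, score in rows if (score >= 0.5) == bool(target))
        print(f"{split_name}: {hits}/{len(rows)} correct at a 0.5 threshold")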