# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

import pathlib

import click

from clapper.click import ResourceOption, verbosity_option
from clapper.logging import setup

from .click import ConfigCommand

logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")


@click.command(
    entry_point_group="ptbench.config",
    cls=ConfigCommand,
    epilog="""Examples:

1. Runs prediction on an existing datamodule configuration:

   .. code:: sh

      ptbench predict -vv pasa montgomery --weight=path/to/model.ckpt --output=path/to/predictions.json

2. Enables multi-processing data loading with 6 processes:

   .. code:: sh

      ptbench predict -vv pasa montgomery --parallel=6 --weight=path/to/model.ckpt --output=path/to/predictions.json

""",
)
@click.option(
    "--output",
    "-o",
    help="""Path where to store the JSON predictions for all samples in the
    input datamodule (leading directories are created if they do not
    exist).""",
    required=True,
    default="results",
    cls=ResourceOption,
    type=click.Path(
        file_okay=True, dir_okay=False, writable=True, path_type=pathlib.Path
    ),
)
@click.option(
    "--model",
    "-m",
    help="""A lightning module instance implementing the network architecture
    (not necessarily the weights) to be used for prediction.""",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--datamodule",
    "-d",
    help="""A lightning data module that will be asked for prediction data
    loaders.  Typically, this includes all configured splits in a datamodule,
    however this is not a requirement.  A datamodule that returns a single
    dataloader for prediction (wrapped in a dictionary) is acceptable.""",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--batch-size",
    "-b",
    help="""Number of samples in every batch (this parameter affects memory
    requirements for the network).""",
    required=True,
    show_default=True,
    default=1,
    type=click.IntRange(min=1),
    cls=ResourceOption,
)
@click.option(
    "--device",
    "-x",
    help='A string indicating the device to use (e.g. "cpu" or "cuda:0")',
    show_default=True,
    required=True,
    default="cpu",
    cls=ResourceOption,
)
@click.option(
    "--weight",
    "-w",
    help="""Path or URL to a pretrained model file (`.ckpt` extension),
    corresponding to the architecture set with `--model`.""",
    required=True,
    cls=ResourceOption,
    type=click.Path(exists=True, file_okay=True, dir_okay=False, readable=True),
)
@click.option(
    "--parallel",
    "-P",
    help="""Use multiprocessing for data loading: if set to -1 (default),
    disables multiprocessing data loading.  Set to 0 to enable as many data
    loading instances as there are processing cores available in the system.
    Set to >= 1 to enable that many multiprocessing instances for data
    loading.""",
    type=click.IntRange(min=-1),
    show_default=True,
    required=True,
    default=-1,
    cls=ResourceOption,
)
@verbosity_option(logger=logger, cls=ResourceOption, expose_value=False)
def predict(
    output,
    model,
    datamodule,
    batch_size,
    device,
    weight,
    parallel,
    **_,
) -> None:
    """Runs inference (generates scores) on all input images, using a
    pre-trained model."""

    import json
    import shutil

    from ..engine.device import DeviceManager
    from ..engine.predictor import run

    # configure data loading (batch size, parallelism) and apply the
    # model-specific transforms before setting up the prediction stage
    datamodule.set_chunk_size(batch_size, 1)
    datamodule.parallel = parallel
    datamodule.model_transforms = model.model_transforms

    datamodule.prepare_data()
    datamodule.setup(stage="predict")

    # load the pre-trained weights into a fresh instance of the model class
    logger.info(f"Loading checkpoint from `{weight}`...")
    model = type(model).load_from_checkpoint(weight, strict=False)

    predictions = run(model, datamodule, DeviceManager(device))

    # save predictions as JSON, backing up any pre-existing output file first
    output.parent.mkdir(parents=True, exist_ok=True)
    if output.exists():
        backup = output.parent / (output.name + "~")
        logger.warning(
            f"Output predictions file `{str(output)}` exists - "
            f"backing it up to `{str(backup)}`..."
        )
        shutil.copy(output, backup)

    with output.open("w") as f:
        json.dump(predictions, f, indent=2)
    logger.info(f"Predictions saved to `{str(output)}`")
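
# Minimal usage sketch: the predictions file written above is plain JSON, so
# it can be read back with the standard library for further analysis.  The
# output path below is an assumption (whatever was passed to ``--output``),
# and the structure of the decoded object depends on how
# ``..engine.predictor.run`` organises its results:
#
#   import json
#
#   with open("results/predictions.json") as f:
#       predictions = json.load(f)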