# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

import click

from clapper.click import ConfigCommand, ResourceOption, verbosity_option
from clapper.logging import setup

# Logger named after the top-level package of this module.
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")


@click.command(
    entry_point_group="ptbench.config",
    cls=ConfigCommand,
    epilog="""Examples:

\b
    1. Runs prediction on an existing dataset configuration:

       .. code:: sh

          ptbench predict -vv pasa montgomery --weight=path/to/model_final.pth --output-folder=path/to/predictions

""",
)
@click.option(
    "--output-folder",
    "-o",
    help="Path where to store the predictions (created if does not exist)",
    required=True,
    default="results",
    cls=ResourceOption,
    type=click.Path(),
)
@click.option(
    "--model",
    "-m",
    help="A torch.nn.Module instance implementing the network to be evaluated",
    required=True,
    cls=ResourceOption,
)
# NOTE: ``--datamodule`` deliberately has no short flag.  It used to declare
# ``-d`` as well, but so does ``--device`` below; click keeps a single mapping
# per short flag and the option registered last (``--device``) wins, so the
# alias here was never reachable and only made the help output misleading.
@click.option(
    "--datamodule",
    help="A torch.utils.data.dataset.Dataset instance implementing a dataset "
    "to be used for running prediction, possibly including all pre-processing "
    "pipelines required or, optionally, a dictionary mapping string keys to "
    "torch.utils.data.dataset.Dataset instances.  All keys that do not start "
    "with an underscore (_) will be processed.",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--batch-size",
    "-b",
    help="Number of samples in every batch (this parameter affects memory requirements for the network)",
    required=True,
    show_default=True,
    default=1,
    type=click.IntRange(min=1),
    cls=ResourceOption,
)
@click.option(
    "--device",
    "-d",
    help='A string indicating the device to use (e.g. "cpu" or "cuda:0")',
    show_default=True,
    required=True,
    default="cpu",
    cls=ResourceOption,
)
@click.option(
    "--weight",
    "-w",
    help="Path or URL to pretrained model file (.ckpt extension)",
    required=True,
    cls=ResourceOption,
)
@verbosity_option(logger=logger, cls=ResourceOption, expose_value=False)
def predict(
    output_folder,
    model,
    datamodule,
    batch_size,
    device,
    weight,
    **_,
) -> None:
    """Predicts Tuberculosis presence (probabilities) on input images."""
    # The docstring above is rendered by click as the command's ``--help``
    # text, so developer-facing documentation is kept in comments instead.
    #
    # Parameters (all injected by click from the options declared above):
    #   output_folder: directory handed to the predictor; also receives the
    #       logistic-regression weights plot when applicable.
    #   model: object exposing ``model_transforms``, ``load_from_checkpoint``,
    #       ``name`` and ``parameters()``.
    #   datamodule: object exposing ``set_chunk_size``, ``prepare_data`` and
    #       ``setup``.
    #   batch_size: forwarded as the first argument of ``set_chunk_size``.
    #   device: device string wrapped in a ``DeviceManager``.
    #   weight: checkpoint location passed to ``load_from_checkpoint``.
    #   **_: any extra keyword arguments are accepted and ignored.

    # Heavy/optional dependencies are imported lazily, inside the command.
    import os

    import numpy as np

    from matplotlib.backends.backend_pdf import PdfPages

    from ..engine.device import DeviceManager
    from ..engine.predictor import run
    from ..utils.plot import relevance_analysis_plot

    datamodule.set_chunk_size(batch_size, 1)
    datamodule.model_transforms = model.model_transforms

    datamodule.prepare_data()
    datamodule.setup(stage="predict")

    logger.info(f"Loading checkpoint from {weight}")
    # NOTE(review): this is invoked on the instance; if ``load_from_checkpoint``
    # is a classmethod that rejects instance calls in the installed framework
    # version, it should be called on the class instead -- TODO confirm.
    model = model.load_from_checkpoint(weight, strict=False)

    # Logistic regressor weights
    if model.name == "logistic_regression":
        logger.info("Logistic regression identified: saving model weights")

        # Only the first parameter returned by the model is plotted.
        first_param = next(iter(model.parameters()))
        model_weights = np.array(first_param.data.reshape(-1))

        filepath = os.path.join(output_folder, "LogReg_Weights.pdf")
        logger.info(f"Creating and saving weights plot at {filepath}...")
        os.makedirs(os.path.dirname(filepath), exist_ok=True)

        # The context manager closes the PDF even if plotting or saving fails.
        with PdfPages(filepath) as pdf:
            pdf.savefig(
                relevance_analysis_plot(
                    model_weights, title="LogReg model weights"
                )
            )

    run(model, datamodule, DeviceManager(device), output_folder)