diff --git a/src/ptbench/scripts/cli.py b/src/ptbench/scripts/cli.py
index 9f27124bac0f9516f9f512db6d779be3edbb88a6..3b27fe7819359c97f5cd1ab8d6a9e3f8fcbf09d7 100644
--- a/src/ptbench/scripts/cli.py
+++ b/src/ptbench/scripts/cli.py
@@ -7,7 +7,6 @@ import click
 from clapper.click import AliasedGroup
 
 from . import (
-    compare,
     config,
     database,
     evaluate,
@@ -27,7 +26,6 @@ def cli():
     pass
 
 
-cli.add_command(compare.compare)
 cli.add_command(config.config)
 cli.add_command(database.database)
 cli.add_command(evaluate.evaluate)
diff --git a/src/ptbench/scripts/compare.py b/src/ptbench/scripts/compare.py
deleted file mode 100644
index 6bc70ccc5832a6693a15d238c636e01a0a77ab0e..0000000000000000000000000000000000000000
--- a/src/ptbench/scripts/compare.py
+++ /dev/null
@@ -1,199 +0,0 @@
-# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-
-import click
-
-from clapper.click import verbosity_option
-from clapper.logging import setup
-
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-
-def _validate_threshold(t, dataset):
-    """Validates the user threshold selection.
-
-    Returns parsed threshold.
-    """
-    if t is None:
-        return t
-
-    # we try to convert it to float first
-    t = float(t)
-    if t < 0.0 or t > 1.0:
-        raise ValueError("Thresholds must be within range [0.0, 1.0]")
-
-    return t
-
-
-def _load(data, threshold):
-    """Plots comparison chart of all evaluated models.
-
-    Parameters
-    ----------
-
-    data : dict
-        A dict in which keys are the names of the systems and the values are
-        paths to ``predictions.csv`` style files.
-
-    threshold : :py:class:`float`
-        A threshold for the final classification.
-
-
-    Returns
-    -------
-
-    data : dict
-        A dict in which keys are the names of the systems and the values are
-        dictionaries that contain two keys:
-
-        * ``df``: A :py:class:`pandas.DataFrame` with the predictions data
-          loaded to
-        * ``threshold``: The ``threshold`` parameter set on the input
-    """
-    import re
-
-    import pandas
-    import torch
-
-    use_threshold = threshold
-    logger.info(f"Dataset '*': threshold = {use_threshold:.3f}'")
-
-    # loads all data
-    retval = {}
-    for name, predictions_path in data.items():
-        # Load predictions
-        logger.info(f"Loading predictions from {predictions_path}...")
-        pred_data = pandas.read_csv(predictions_path)
-        pred = (
-            torch.Tensor(
-                [
-                    eval(
-                        re.sub(" +", " ", x.replace("\n", "")).replace(" ", ",")
-                    )
-                    if isinstance(x, str)
-                    else x
-                    for x in pred_data["likelihood"].values
-                ]
-            )
-            .double()
-            .flatten()
-        )
-        gt = (
-            torch.Tensor(
-                [
-                    eval(
-                        re.sub(" +", " ", x.replace("\n", "")).replace(" ", ",")
-                    )
-                    if isinstance(x, str)
-                    else x
-                    for x in pred_data["ground_truth"].values
-                ]
-            )
-            .double()
-            .flatten()
-        )
-
-        pred_data["likelihood"] = pred
-        pred_data["ground_truth"] = gt
-
-        retval[name] = dict(df=pred_data, threshold=use_threshold)
-
-    return retval
-
-
-@click.command(
-    epilog="""Examples:
-
-\b
-    1. Compares system A and B, with their own predictions files:
-
-       .. code:: sh
-
-          ptbench compare -vv A path/to/A/predictions.csv B path/to/B/predictions.csv
-""",
-)
-@click.argument(
-    "label_path",
-    nargs=-1,
-)
-@click.option(
-    "--output-figure",
-    "-f",
-    help="Path where write the output figure (any extension supported by "
-    "matplotlib is possible). If not provided, does not produce a figure.",
-    required=False,
-    default=None,
-    type=click.Path(dir_okay=False, file_okay=True),
-)
-@click.option(
-    "--table-format",
-    "-T",
-    help="The format to use for the comparison table",
-    show_default=True,
-    required=True,
-    default="rst",
-    type=click.Choice(__import__("tabulate").tabulate_formats),
-)
-@click.option(
-    "--output-table",
-    "-u",
-    help="Path where write the output table. If not provided, does not write "
-    "write a table to file, only to stdout.",
-    required=False,
-    default=None,
-    type=click.Path(dir_okay=False, file_okay=True),
-)
-@click.option(
-    "--threshold",
-    "-t",
-    help="This number is used to separate positive and negative cases "
-    "by thresholding their score.",
-    default=None,
-    show_default=False,
-    required=False,
-)
-@verbosity_option(logger=logger, expose_value=False)
-def compare(
-    label_path, output_figure, table_format, output_table, threshold
-) -> None:
-    """Compares multiple systems together."""
-
-    import os
-
-    from matplotlib.backends.backend_pdf import PdfPages
-
-    from ..utils.plot import precision_recall_f1iso, roc_curve
-    from ..utils.table import performance_table
-
-    # hack to get a dictionary from arguments passed to input
-    if len(label_path) % 2 != 0:
-        raise click.ClickException(
-            "Input label-paths should be doubles"
-            " composed of name-path entries"
-        )
-    data = dict(zip(label_path[::2], label_path[1::2]))
-
-    threshold = _validate_threshold(threshold, data)
-
-    # load all data measures
-    data = _load(data, threshold=threshold)
-
-    if output_figure is not None:
-        output_figure = os.path.realpath(output_figure)
-        logger.info(f"Creating and saving plot at {output_figure}...")
-        os.makedirs(os.path.dirname(output_figure), exist_ok=True)
-        pdf = PdfPages(output_figure)
-        pdf.savefig(precision_recall_f1iso(data))
-        pdf.savefig(roc_curve(data))
-        pdf.close()
-
-    logger.info("Tabulating performance summary...")
-    table = performance_table(data, table_format)
-    click.echo(table)
-    if output_table is not None:
-        output_table = os.path.realpath(output_table)
-        logger.info(f"Saving table at {output_table}...")
-        os.makedirs(os.path.dirname(output_table), exist_ok=True)
-        with open(output_table, "w") as f:
-            f.write(table)
diff --git a/src/ptbench/scripts/relevance_analysis.py b/src/ptbench/scripts/relevance_analysis.py
deleted file mode 100644
index 6be7abd5bdf49995a3910404d971f523893537a7..0000000000000000000000000000000000000000
--- a/src/ptbench/scripts/relevance_analysis.py
+++ /dev/null
@@ -1,83 +0,0 @@
-# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-"""Import copy import os import shutil.
-
-import numpy as np
-import torch
-
-from matplotlib.backends.backend_pdf import PdfPages
-from sklearn import metrics
-from torch.utils.data import ConcatDataset, DataLoader
-
-from ..engine.predictor import run
-from ..utils.plot import relevance_analysis_plot
-
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-
-# Relevance analysis using permutation feature importance
-if relevance_analysis:
-    if isinstance(v, ConcatDataset) or not isinstance(
-        v._samples[0].data["data"], list
-    ):
-        logger.info(
-            "Relevance analysis only possible with radiological signs as input. Cancelling..."
-        )
-        continue
-
-    nb_features = len(v._samples[0].data["data"])
-
-    if nb_features == 1:
-        logger.info("Relevance analysis not possible with one feature")
-    else:
-        logger.info(f"Starting relevance analysis for subset '{k}'...")
-
-        all_mse = []
-        for f in range(nb_features):
-            v_original = copy.deepcopy(v)
-
-            # Randomly permute feature values from all samples
-            v.random_permute(f)
-
-            data_loader = DataLoader(
-                dataset=v,
-                batch_size=batch_size,
-                shuffle=False,
-                pin_memory=torch.cuda.is_available(),
-            )
-
-            predictions_with_mean = run(
-                model,
-                data_loader,
-                k,
-                accelerator,
-                output_folder + "_temp",
-            )
-
-            # Compute MSE between original and new predictions
-            all_mse.append(
-                metrics.mean_squared_error(
-                    np.array(predictions, dtype=object)[:, 1],
-                    np.array(predictions_with_mean, dtype=object)[:, 1],
-                )
-            )
-
-            # Back to original values
-            v = v_original
-
-        # Remove temporary folder
-        shutil.rmtree(output_folder + "_temp", ignore_errors=True)
-
-        filepath = os.path.join(output_folder, k + "_RA.pdf")
-        logger.info(f"Creating and saving plot at {filepath}...")
-        os.makedirs(os.path.dirname(filepath), exist_ok=True)
-        pdf = PdfPages(filepath)
-        pdf.savefig(
-            relevance_analysis_plot(
-                all_mse,
-                title=k.capitalize() + " set relevance analysis",
-            )
-        )
-        pdf.close()
-"""
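With ``compare.py`` removed, no in-tree command loads one ``predictions.csv`` per system, writes ROC plots into a single PDF, and prints a summary table. The sketch below approximates that workflow with plain pandas, scikit-learn, and matplotlib only; it does not use the removed ``ptbench`` helpers, the file paths are hypothetical, and, unlike the removed ``_load()``, it assumes the ``likelihood`` and ``ground_truth`` columns hold scalars rather than stringified arrays:

.. code:: py

   import pandas
   from matplotlib.backends.backend_pdf import PdfPages
   from sklearn.metrics import RocCurveDisplay, roc_auc_score

   # Hypothetical inputs: one predictions file per system to compare.
   systems = {
       "A": "path/to/A/predictions.csv",
       "B": "path/to/B/predictions.csv",
   }

   with PdfPages("comparison.pdf") as pdf:  # one ROC page per system
       for name, path in systems.items():
           df = pandas.read_csv(path)
           # Scalar scores assumed; summarize each system on stdout.
           auc = roc_auc_score(df["ground_truth"], df["likelihood"])
           print(f"{name}: ROC-AUC = {auc:.3f}")
           display = RocCurveDisplay.from_predictions(
               df["ground_truth"], df["likelihood"], name=name
           )
           pdf.savefig(display.figure_)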
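The deleted ``relevance_analysis.py`` fragment implemented permutation feature importance by hand: permute one input feature at a time, re-run prediction, and record the mean squared error against the unperturbed predictions. For scikit-learn style estimators the same technique is available off the shelf; the sketch below is a minimal illustration on synthetic data, not a drop-in replacement for the removed ptbench code:

.. code:: py

   import numpy as np
   from sklearn.ensemble import RandomForestClassifier
   from sklearn.inspection import permutation_importance

   # Synthetic stand-in for per-sample feature vectors and binary labels.
   rng = np.random.default_rng(0)
   X = rng.normal(size=(200, 14))
   y = (X[:, 0] + 0.1 * rng.normal(size=200) > 0).astype(int)

   model = RandomForestClassifier(random_state=0).fit(X, y)

   # Permute each feature in turn and measure the drop in score,
   # mirroring the permute-and-compare loop of the removed script.
   result = permutation_importance(model, X, y, n_repeats=10, random_state=0)
   for i in np.argsort(result.importances_mean)[::-1]:
       print(f"feature {i}: {result.importances_mean[i]:.3f}")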