Skip to content
Snippets Groups Projects
Commit 5e759a6d authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[scripts] Remove outdated scripts

parent d3d65b93
No related branches found
No related tags found
1 merge request!6Making use of LightningDataModule and simplification of data loading
Pipeline #77169 failed
......@@ -7,7 +7,6 @@ import click
from clapper.click import AliasedGroup
from . import (
compare,
config,
database,
evaluate,
......@@ -27,7 +26,6 @@ def cli():
pass
cli.add_command(compare.compare)
cli.add_command(config.config)
cli.add_command(database.database)
cli.add_command(evaluate.evaluate)
......
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
import click
from clapper.click import verbosity_option
from clapper.logging import setup
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
def _validate_threshold(t, dataset):
"""Validates the user threshold selection.
Returns parsed threshold.
"""
if t is None:
return t
# we try to convert it to float first
t = float(t)
if t < 0.0 or t > 1.0:
raise ValueError("Thresholds must be within range [0.0, 1.0]")
return t
def _load(data, threshold):
    """Load the predictions of every system to be compared.

    Parameters
    ----------
    data : dict
        A dict in which keys are the names of the systems and the values are
        paths to ``predictions.csv`` style files.
    threshold : :py:class:`float` or :py:class:`None`
        A threshold for the final classification, or ``None`` if no threshold
        was selected.

    Returns
    -------
    data : dict
        A dict in which keys are the names of the systems and the values are
        dictionaries that contain two keys:

        * ``df``: A :py:class:`pandas.DataFrame` with the predictions data
          loaded to
        * ``threshold``: The ``threshold`` parameter set on the input
    """
    import ast
    import re

    import pandas
    import torch

    use_threshold = threshold
    if use_threshold is None:
        # formatting None with ":.3f" would raise a TypeError, so log a
        # dedicated message instead of crashing
        logger.info("Dataset '*': no threshold selected")
    else:
        logger.info(f"Dataset '*': threshold = {use_threshold:.3f}")

    def _to_tensor(values):
        """Parse a CSV column of scalars or array-style strings into a
        flattened double tensor."""
        # string cells look like numpy reprs, e.g. "[0.1 0.2]", possibly
        # spanning several lines; normalize whitespace to commas and parse
        # with ast.literal_eval (safe, unlike eval() on file contents)
        parsed = [
            ast.literal_eval(
                re.sub(" +", " ", x.replace("\n", "")).replace(" ", ",")
            )
            if isinstance(x, str)
            else x
            for x in values
        ]
        return torch.Tensor(parsed).double().flatten()

    # loads all data
    retval = {}
    for name, predictions_path in data.items():
        # Load predictions
        logger.info(f"Loading predictions from {predictions_path}...")
        pred_data = pandas.read_csv(predictions_path)
        pred_data["likelihood"] = _to_tensor(pred_data["likelihood"].values)
        pred_data["ground_truth"] = _to_tensor(pred_data["ground_truth"].values)
        retval[name] = dict(df=pred_data, threshold=use_threshold)
    return retval
@click.command(
    epilog="""Examples:
\b
1. Compares system A and B, with their own predictions files:
.. code:: sh
ptbench compare -vv A path/to/A/predictions.csv B path/to/B/predictions.csv
""",
)
@click.argument(
    "label_path",
    nargs=-1,
)
@click.option(
    "--output-figure",
    "-f",
    help="Path where write the output figure (any extension supported by "
    "matplotlib is possible). If not provided, does not produce a figure.",
    required=False,
    default=None,
    type=click.Path(dir_okay=False, file_okay=True),
)
@click.option(
    "--table-format",
    "-T",
    help="The format to use for the comparison table",
    show_default=True,
    required=True,
    default="rst",
    type=click.Choice(__import__("tabulate").tabulate_formats),
)
@click.option(
    "--output-table",
    "-u",
    help="Path where write the output table. If not provided, does not "
    "write a table to file, only to stdout.",
    required=False,
    default=None,
    type=click.Path(dir_okay=False, file_okay=True),
)
@click.option(
    "--threshold",
    "-t",
    help="This number is used to separate positive and negative cases "
    "by thresholding their score.",
    default=None,
    show_default=False,
    required=False,
)
@verbosity_option(logger=logger, expose_value=False)
def compare(
    label_path, output_figure, table_format, output_table, threshold
) -> None:
    """Compares multiple systems together.

    LABEL_PATH is an even-length sequence alternating system names and
    paths to their ``predictions.csv`` files.
    """
    import os

    from matplotlib.backends.backend_pdf import PdfPages

    from ..utils.plot import precision_recall_f1iso, roc_curve
    from ..utils.table import performance_table

    # hack to get a dictionary from arguments passed to input
    if len(label_path) % 2 != 0:
        raise click.ClickException(
            "Input label-paths should be doubles"
            " composed of name-path entries"
        )
    # pair up (name, path) arguments: even indices are names, odd are paths
    data = dict(zip(label_path[::2], label_path[1::2]))
    threshold = _validate_threshold(threshold, data)
    # load all data measures
    data = _load(data, threshold=threshold)
    if output_figure is not None:
        output_figure = os.path.realpath(output_figure)
        logger.info(f"Creating and saving plot at {output_figure}...")
        # realpath() above guarantees a non-empty directory component
        os.makedirs(os.path.dirname(output_figure), exist_ok=True)
        pdf = PdfPages(output_figure)
        pdf.savefig(precision_recall_f1iso(data))
        pdf.savefig(roc_curve(data))
        pdf.close()
    logger.info("Tabulating performance summary...")
    table = performance_table(data, table_format)
    click.echo(table)
    if output_table is not None:
        output_table = os.path.realpath(output_table)
        logger.info(f"Saving table at {output_table}...")
        os.makedirs(os.path.dirname(output_table), exist_ok=True)
        with open(output_table, "w") as f:
            f.write(table)
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

# NOTE(review): everything below is DEAD CODE — a former "relevance
# analysis" routine (permutation feature importance) that was disabled by
# wrapping it in a single module-level string literal: the triple quote on
# the next line is only closed at the very end of this block, so none of
# the "imports" or statements inside ever execute.  The text references
# names (copy, relevance_analysis, v, k, model, batch_size, accelerator,
# output_folder, predictions, setup) that are never defined here.  Kept
# only for reference; presumably safe to delete entirely — confirm with
# the commit history ("[scripts] Remove outdated scripts") before doing so.
"""Import copy import os import shutil.
import numpy as np
import torch
from matplotlib.backends.backend_pdf import PdfPages
from sklearn import metrics
from torch.utils.data import ConcatDataset, DataLoader
from ..engine.predictor import run
from ..utils.plot import relevance_analysis_plot
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
# Relevance analysis using permutation feature importance
if relevance_analysis:
if isinstance(v, ConcatDataset) or not isinstance(
v._samples[0].data["data"], list
):
logger.info(
"Relevance analysis only possible with radiological signs as input. Cancelling..."
)
continue
nb_features = len(v._samples[0].data["data"])
if nb_features == 1:
logger.info("Relevance analysis not possible with one feature")
else:
logger.info(f"Starting relevance analysis for subset '{k}'...")
all_mse = []
for f in range(nb_features):
v_original = copy.deepcopy(v)
# Randomly permute feature values from all samples
v.random_permute(f)
data_loader = DataLoader(
dataset=v,
batch_size=batch_size,
shuffle=False,
pin_memory=torch.cuda.is_available(),
)
predictions_with_mean = run(
model,
data_loader,
k,
accelerator,
output_folder + "_temp",
)
# Compute MSE between original and new predictions
all_mse.append(
metrics.mean_squared_error(
np.array(predictions, dtype=object)[:, 1],
np.array(predictions_with_mean, dtype=object)[:, 1],
)
)
# Back to original values
v = v_original
# Remove temporary folder
shutil.rmtree(output_folder + "_temp", ignore_errors=True)
filepath = os.path.join(output_folder, k + "_RA.pdf")
logger.info(f"Creating and saving plot at {filepath}...")
os.makedirs(os.path.dirname(filepath), exist_ok=True)
pdf = PdfPages(filepath)
pdf.savefig(
relevance_analysis_plot(
all_mse,
title=k.capitalize() + " set relevance analysis",
)
)
pdf.close()
"""
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment