Commit d5b173b1 authored by Daniel CARRON

Updated prediction scripts for compatibility with DataModules

Relevance analysis has been moved out of predict.py and into a separate
script. It is not currently functional.
parent 27363e33
Pipeline #75311 canceled
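For orientation, here is a minimal sketch of how the pieces changed below are meant to fit together after this commit: instantiate a DataModule, load the model from a checkpoint, and hand both to the prediction engine, which now writes one CSV per predict dataloader. The package path and the DataModule class name are placeholders; only run(), PASA, load_from_checkpoint() and the argument names appear in the diff itself.

from mypackage.engine.predictor import run  # placeholder path; the project uses a relative import

datamodule = SomeDataModule(batch_size=4)   # placeholder for any DataModule accepting batch_size
model = PASA.load_from_checkpoint("path/to/checkpoint.ckpt", strict=False)
_ = run(model, datamodule, accelerator="cpu", output_folder="results")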
@@ -109,7 +109,7 @@ class BaseDataModule(pl.LightningDataModule):
     def predict_dataloader(self):
         loaders_dict = {}

-        loaders_dict["train_dataloader"] = self.train_dataloader()
+        loaders_dict["train_loader"] = self.train_dataloader()

         for k, v in self.val_dataloader().items():
             loaders_dict[k] = v
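The key rename above matters because the new PredictionsWriter callback (next file) derives each output sub-directory by stripping a trailing "_loader" from the dataloader keys. A small sketch of the assumed key convention; the validation key name is a guess based on the .replace("_loader", "") call below:

# Stand-ins for real DataLoader objects; only the keys matter here.
loaders_dict = {"train_loader": None, "validation_loader": None}
for key in loaders_dict:
    print(key.replace("_loader", ""))  # -> "train", "validation"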
import csv
import os
import time
from collections import defaultdict
@@ -141,24 +142,34 @@ class LoggingCallback(Callback):
 class PredictionsWriter(BasePredictionWriter):
     """Lightning callback to write predictions to a file."""

-    def __init__(self, logfile_name, logfile_fields, write_interval):
+    def __init__(self, output_dir, logfile_fields, write_interval):
         super().__init__(write_interval)
-        self.logfile_name = logfile_name
+        self.output_dir = output_dir
         self.logfile_fields = logfile_fields

     def write_on_epoch_end(
         self, trainer, pl_module, predictions, batch_indices
     ):
-        with open(self.logfile_name, "w") as logfile:
-            logwriter = csv.DictWriter(logfile, fieldnames=self.logfile_fields)
-            logwriter.writeheader()
-            for prediction in predictions:
-                logwriter.writerow(
-                    {
-                        "filename": prediction[0],
-                        "likelihood": prediction[1].numpy(),
-                        "ground_truth": prediction[2].numpy(),
-                    }
-                )
-            logfile.flush()
+        for dataloader_idx, dataloader_results in enumerate(predictions):
+            dataloader_name = list(
+                trainer.datamodule.predict_dataloader().keys()
+            )[dataloader_idx].replace("_loader", "")
+
+            logfile = os.path.join(
+                self.output_dir, dataloader_name, "predictions.csv"
+            )
+            os.makedirs(os.path.dirname(logfile), exist_ok=True)
+
+            with open(logfile, "w") as l_f:
+                logwriter = csv.DictWriter(l_f, fieldnames=self.logfile_fields)
+                logwriter.writeheader()
+                for prediction in dataloader_results:
+                    logwriter.writerow(
+                        {
+                            "filename": prediction[0],
+                            "likelihood": prediction[1].numpy(),
+                            "ground_truth": prediction[2].numpy(),
+                        }
+                    )
+                l_f.flush()
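With the rewritten callback above, each predict dataloader gets its own predictions.csv under output_dir. A hedged read-back example; the output directory and subset names are assumptions following the key convention sketched earlier:

import csv
import os

output_dir = "results"  # whatever was passed as output_dir to PredictionsWriter
for subset in ("train", "validation"):  # assumed dataloader names
    path = os.path.join(output_dir, subset, "predictions.csv")
    if not os.path.exists(path):
        continue
    with open(path) as f:
        for row in csv.DictReader(f):
            print(subset, row["filename"], row["likelihood"], row["ground_truth"])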
@@ -13,7 +13,7 @@ from .callbacks import PredictionsWriter
 logger = logging.getLogger(__name__)


-def run(model, datamodule, name, accelerator, output_folder, grad_cams=False):
+def run(model, datamodule, accelerator, output_folder, grad_cams=False):
     """Runs inference on input data, outputs csv files with predictions.

     Parameters
@@ -24,10 +24,6 @@ def run(model, datamodule, name, accelerator, output_folder, grad_cams=False):
     data_loader : py:class:`torch.torch.utils.data.DataLoader`
         The pytorch Dataloader used to iterate over batches.

-    name : str
-        The local name of this dataset (e.g. ``train``, or ``test``), to be
-        used when saving measures files.
-
     accelerator : str
         A string indicating the accelerator to use (e.g. "cpu" or "gpu"). The device can also be specified (gpu:0)
@@ -44,7 +40,6 @@ def run(model, datamodule, name, accelerator, output_folder, grad_cams=False):
     all_predictions : list
         All the predictions associated with filename and ground truth.
     """
-    output_folder = os.path.join(output_folder, name)
     logger.info(f"Output folder: {output_folder}")
     os.makedirs(output_folder, exist_ok=True)
@@ -58,7 +53,6 @@ def run(model, datamodule, name, accelerator, output_folder, grad_cams=False):
     logger.info(f"Device: {devices}")

-    logfile_name = os.path.join(output_folder, "predictions.csv")
     logfile_fields = ("filename", "likelihood", "ground_truth")

     trainer = Trainer(
@@ -66,7 +60,7 @@ def run(model, datamodule, name, accelerator, output_folder, grad_cams=False):
         devices=devices,
         callbacks=[
             PredictionsWriter(
-                logfile_name=logfile_name,
+                output_dir=output_folder,
                 logfile_fields=logfile_fields,
                 write_interval="epoch",
             ),
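The hunk above only rewires the callback; the trainer is presumably still asked to predict over the whole datamodule further down, in which case Lightning returns one list of per-batch results per dataloader, which is exactly the nesting that write_on_epoch_end() iterates over. A hedged sketch of that call (not shown in this diff):

predictions = trainer.predict(model, datamodule=datamodule, return_predictions=True)
# With several predict dataloaders, `predictions` is a list (one entry per
# dataloader) of lists (one entry per batch), matching the
# enumerate(predictions) loop in PredictionsWriter.write_on_epoch_end().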
@@ -169,8 +169,9 @@ class PASA(pl.LightningModule):
         return {f"extra_validation_loss_{dataloader_idx}": validation_loss}

     def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False):
-        names = batch["names"]
+        names = batch["name"]
         images = batch["data"]
+        labels = batch["label"]

         outputs = self(images)
         probabilities = torch.sigmoid(outputs)
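For reference, the likelihood written to the CSV is just the element-wise logistic of the model output; a minimal, self-contained example:

import torch

logits = torch.tensor([-2.0, 0.0, 3.0])  # example model outputs
print(torch.sigmoid(logits))             # tensor([0.1192, 0.5000, 0.9526])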
@@ -180,18 +181,34 @@ class PASA(pl.LightningModule):
         if isinstance(outputs, list):
             outputs = outputs[-1]

-        return {
-            f"dataloader_{dataloader_idx}_predictions": (
-                names[0],
-                torch.flatten(probabilities),
-                torch.flatten(batch[2]),
-            )
-        }
-
-    def on_predict_epoch_end(self):
-        # Need to cache predictions in the predict step, then reorder by key
-        # Clear prediction dict
-        raise NotImplementedError
+        results = (
+            names[0],
+            torch.flatten(probabilities),
+            torch.flatten(labels),
+        )
+        return results
+        # {
+        #     f"dataloader_{dataloader_idx}_predictions": (
+        #         names[0],
+        #         torch.flatten(probabilities),
+        #         torch.flatten(labels),
+        #     )
+        # }
+
+    # def on_predict_epoch_end(self):
+    #     retval = defaultdict(list)
+
+    #     for dataloader_name, predictions in self.predictions_cache.items():
+    #         for prediction in predictions:
+    #             retval[dataloader_name]["name"].append(prediction[0])
+    #             retval[dataloader_name]["prediction"].append(prediction[1])
+    #             retval[dataloader_name]["label"].append(prediction[2])
+
+    # Need to cache predictions in the predict step, then reorder by key
+    # Clear prediction dict
+    # raise NotImplementedError

     def configure_optimizers(self):
         # Dynamically instantiates the optimizer given the configs
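The commented-out hook above hints at caching predictions per dataloader in predict_step and regrouping them by field at the end of the epoch. Below is a standalone illustration of that regrouping idea with made-up cache contents, not the author's implementation; note the nested defaultdict, since the plain defaultdict(list) from the comment cannot be indexed by field a second time.

from collections import defaultdict

# Hypothetical cache, as predict_step might fill it:
# dataloader name -> list of (name, prediction, label) tuples.
predictions_cache = {
    "train_loader": [("img_0.png", 0.91, 1), ("img_1.png", 0.12, 0)],
    "validation_loader": [("img_2.png", 0.55, 1)],
}

regrouped = defaultdict(lambda: defaultdict(list))
for dataloader_name, predictions in predictions_cache.items():
    for name, prediction, label in predictions:
        regrouped[dataloader_name]["name"].append(name)
        regrouped[dataloader_name]["prediction"].append(prediction)
        regrouped[dataloader_name]["label"].append(label)
predictions_cache.clear()  # "clear prediction dict", as the comment says

print(dict(regrouped))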
@@ -41,7 +41,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
     cls=ResourceOption,
 )
 @click.option(
-    "--dataset",
+    "--datamodule",
     "-d",
     help="A torch.utils.data.dataset.Dataset instance implementing a dataset "
     "to be used for running prediction, possibly including all pre-processing "
@@ -77,14 +77,6 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
     required=True,
     cls=ResourceOption,
 )
-@click.option(
-    "--relevance-analysis",
-    "-r",
-    help="If set, generate relevance analysis pdfs to indicate the relative"
-    "importance of each feature",
-    is_flag=True,
-    cls=ResourceOption,
-)
 @click.option(
     "--grad-cams",
     "-g",
@@ -96,32 +88,27 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 def predict(
     output_folder,
     model,
-    dataset,
+    datamodule,
     batch_size,
     accelerator,
     weight,
-    relevance_analysis,
     grad_cams,
     **_,
 ) -> None:
     """Predicts Tuberculosis presence (probabilities) on input images."""

     import copy
     import os
     import shutil

     import numpy as np
     import torch
     from matplotlib.backends.backend_pdf import PdfPages
     from sklearn import metrics
     from torch.utils.data import ConcatDataset, DataLoader

     from ..data.datamodule import DataModule
     from ..engine.predictor import run
     from ..utils.plot import relevance_analysis_plot

-    dataset = dataset if isinstance(dataset, dict) else dict(test=dataset)
+    datamodule = datamodule(
+        batch_size=batch_size,
+    )

     logger.info(f"Loading checkpoint from {weight}")
     model = model.load_from_checkpoint(weight, strict=False)
@@ -141,83 +128,4 @@ def predict(
         )

         pdf.close()

-    for k, v in dataset.items():
-        if k.startswith("_"):
-            logger.info(f"Skipping dataset '{k}' (not to be evaluated)")
-            continue
-
-        logger.info(f"Running inference on '{k}' set...")
-
-        datamodule = DataModule(
-            v,
-            train_batch_size=batch_size,
-        )
-
-        predictions = run(
-            model, datamodule, k, accelerator, output_folder, grad_cams
-        )
-
-        # Relevance analysis using permutation feature importance
-        if relevance_analysis:
-            if isinstance(v, ConcatDataset) or not isinstance(
-                v._samples[0].data["data"], list
-            ):
-                logger.info(
-                    "Relevance analysis only possible with radiological signs as input. Cancelling..."
-                )
-                continue
-
-            nb_features = len(v._samples[0].data["data"])
-
-            if nb_features == 1:
-                logger.info("Relevance analysis not possible with one feature")
-            else:
-                logger.info(f"Starting relevance analysis for subset '{k}'...")
-
-                all_mse = []
-                for f in range(nb_features):
-                    v_original = copy.deepcopy(v)
-
-                    # Randomly permute feature values from all samples
-                    v.random_permute(f)
-
-                    data_loader = DataLoader(
-                        dataset=v,
-                        batch_size=batch_size,
-                        shuffle=False,
-                        pin_memory=torch.cuda.is_available(),
-                    )
-
-                    predictions_with_mean = run(
-                        model,
-                        data_loader,
-                        k,
-                        accelerator,
-                        output_folder + "_temp",
-                    )
-
-                    # Compute MSE between original and new predictions
-                    all_mse.append(
-                        metrics.mean_squared_error(
-                            np.array(predictions, dtype=object)[:, 1],
-                            np.array(predictions_with_mean, dtype=object)[:, 1],
-                        )
-                    )
-
-                    # Back to original values
-                    v = v_original
-
-                # Remove temporary folder
-                shutil.rmtree(output_folder + "_temp", ignore_errors=True)
-
-                filepath = os.path.join(output_folder, k + "_RA.pdf")
-                logger.info(f"Creating and saving plot at {filepath}...")
-                os.makedirs(os.path.dirname(filepath), exist_ok=True)
-                pdf = PdfPages(filepath)
-                pdf.savefig(
-                    relevance_analysis_plot(
-                        all_mse,
-                        title=k.capitalize() + " set relevance analysis",
-                    )
-                )
-                pdf.close()
+    _ = run(model, datamodule, accelerator, output_folder, grad_cams)
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Import copy import os import shutil.
import numpy as np
import torch
from matplotlib.backends.backend_pdf import PdfPages
from sklearn import metrics
from torch.utils.data import ConcatDataset, DataLoader
from ..engine.predictor import run
from ..utils.plot import relevance_analysis_plot
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
# Relevance analysis using permutation feature importance
if relevance_analysis:
    if isinstance(v, ConcatDataset) or not isinstance(
        v._samples[0].data["data"], list
    ):
        logger.info(
            "Relevance analysis only possible with radiological signs as input. Cancelling..."
        )
        continue

    nb_features = len(v._samples[0].data["data"])

    if nb_features == 1:
        logger.info("Relevance analysis not possible with one feature")
    else:
        logger.info(f"Starting relevance analysis for subset '{k}'...")

        all_mse = []
        for f in range(nb_features):
            v_original = copy.deepcopy(v)

            # Randomly permute feature values from all samples
            v.random_permute(f)

            data_loader = DataLoader(
                dataset=v,
                batch_size=batch_size,
                shuffle=False,
                pin_memory=torch.cuda.is_available(),
            )

            predictions_with_mean = run(
                model,
                data_loader,
                k,
                accelerator,
                output_folder + "_temp",
            )

            # Compute MSE between original and new predictions
            all_mse.append(
                metrics.mean_squared_error(
                    np.array(predictions, dtype=object)[:, 1],
                    np.array(predictions_with_mean, dtype=object)[:, 1],
                )
            )

            # Back to original values
            v = v_original

        # Remove temporary folder
        shutil.rmtree(output_folder + "_temp", ignore_errors=True)

        filepath = os.path.join(output_folder, k + "_RA.pdf")
        logger.info(f"Creating and saving plot at {filepath}...")
        os.makedirs(os.path.dirname(filepath), exist_ok=True)
        pdf = PdfPages(filepath)
        pdf.savefig(
            relevance_analysis_plot(
                all_mse,
                title=k.capitalize() + " set relevance analysis",
            )
        )
        pdf.close()
"""