Commit 1729b307 authored by Daniel CARRON, committed by André Anjos

Updated prediction scripts for compatibility with DataModules

Relevance analysis has been moved out of predict.py and into a separate
script. It is not currently functional.
parent de0ab7ff
Merge request !6: Making use of LightningDataModule and simplification of data loading
@@ -109,7 +109,7 @@ class BaseDataModule(pl.LightningDataModule):
     def predict_dataloader(self):
         loaders_dict = {}
-        loaders_dict["train_dataloader"] = self.train_dataloader()
+        loaders_dict["train_loader"] = self.train_dataloader()
         for k, v in self.val_dataloader().items():
             loaders_dict[k] = v
 ...
 import csv
+import os
 import time
 from collections import defaultdict
@@ -141,24 +142,34 @@ class LoggingCallback(Callback):
 class PredictionsWriter(BasePredictionWriter):
     """Lightning callback to write predictions to a file."""

-    def __init__(self, logfile_name, logfile_fields, write_interval):
+    def __init__(self, output_dir, logfile_fields, write_interval):
         super().__init__(write_interval)
-        self.logfile_name = logfile_name
+        self.output_dir = output_dir
         self.logfile_fields = logfile_fields

     def write_on_epoch_end(
         self, trainer, pl_module, predictions, batch_indices
     ):
-        with open(self.logfile_name, "w") as logfile:
-            logwriter = csv.DictWriter(logfile, fieldnames=self.logfile_fields)
-            logwriter.writeheader()
-
-            for prediction in predictions:
-                logwriter.writerow(
-                    {
-                        "filename": prediction[0],
-                        "likelihood": prediction[1].numpy(),
-                        "ground_truth": prediction[2].numpy(),
-                    }
-                )
-            logfile.flush()
+        for dataloader_idx, dataloader_results in enumerate(predictions):
+            dataloader_name = list(
+                trainer.datamodule.predict_dataloader().keys()
+            )[dataloader_idx].replace("_loader", "")
+
+            logfile = os.path.join(
+                self.output_dir, dataloader_name, "predictions.csv"
+            )
+            os.makedirs(os.path.dirname(logfile), exist_ok=True)
+
+            with open(logfile, "w") as l_f:
+                logwriter = csv.DictWriter(l_f, fieldnames=self.logfile_fields)
+                logwriter.writeheader()
+
+                for prediction in dataloader_results:
+                    logwriter.writerow(
+                        {
+                            "filename": prediction[0],
+                            "likelihood": prediction[1].numpy(),
+                            "ground_truth": prediction[2].numpy(),
+                        }
+                    )
+                l_f.flush()
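
This loop relies on Lightning handing `write_on_epoch_end` one list of batch results per predict dataloader. Below is a minimal sketch of the assumed nesting, with made-up sample values; the tuples mirror what `predict_step` returns further down in this commit.

import torch

# Assumed structure of `predictions` with two predict dataloaders: one list
# per dataloader, each containing the per-batch outputs of predict_step().
predictions = [
    [("image_0001", torch.tensor([0.87]), torch.tensor([1]))],  # e.g. train_loader
    [("image_0042", torch.tensor([0.12]), torch.tensor([0]))],  # e.g. validation_loader
]

for dataloader_idx, dataloader_results in enumerate(predictions):
    for filename, likelihood, ground_truth in dataloader_results:
        print(dataloader_idx, filename, float(likelihood), int(ground_truth))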
@@ -13,7 +13,7 @@ from .callbacks import PredictionsWriter
 logger = logging.getLogger(__name__)

-def run(model, datamodule, name, accelerator, output_folder, grad_cams=False):
+def run(model, datamodule, accelerator, output_folder, grad_cams=False):
     """Runs inference on input data, outputs csv files with predictions.

     Parameters
@@ -24,10 +24,6 @@ def run(model, datamodule, name, accelerator, output_folder, grad_cams=False):
     data_loader : py:class:`torch.torch.utils.data.DataLoader`
         The pytorch Dataloader used to iterate over batches.

-    name : str
-        The local name of this dataset (e.g. ``train``, or ``test``), to be
-        used when saving measures files.
-
     accelerator : str
         A string indicating the accelerator to use (e.g. "cpu" or "gpu"). The device can also be specified (gpu:0)
@@ -44,7 +40,6 @@ def run(model, datamodule, name, accelerator, output_folder, grad_cams=False):
     all_predictions : list
         All the predictions associated with filename and ground truth.
     """
-    output_folder = os.path.join(output_folder, name)
     logger.info(f"Output folder: {output_folder}")
     os.makedirs(output_folder, exist_ok=True)
@@ -58,7 +53,6 @@ def run(model, datamodule, name, accelerator, output_folder, grad_cams=False):
     logger.info(f"Device: {devices}")

-    logfile_name = os.path.join(output_folder, "predictions.csv")
     logfile_fields = ("filename", "likelihood", "ground_truth")

     trainer = Trainer(
@@ -66,7 +60,7 @@ def run(model, datamodule, name, accelerator, output_folder, grad_cams=False):
         devices=devices,
         callbacks=[
             PredictionsWriter(
-                logfile_name=logfile_name,
+                output_dir=output_folder,
                 logfile_fields=logfile_fields,
                 write_interval="epoch",
             ),
 ...
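
With `PredictionsWriter` taking `output_dir` instead of a pre-built `logfile_name`, `run()` no longer needs the `name` argument: each predict dataloader gets its own CSV. A small sketch of the resulting layout, assuming loader keys such as `train_loader` and `validation_loader` (the second key is an assumption, not shown in this diff):

import os

output_folder = "results"  # hypothetical --output-folder value
for loader_key in ("train_loader", "validation_loader"):  # assumed dataloader keys
    print(os.path.join(output_folder, loader_key.replace("_loader", ""), "predictions.csv"))
# results/train/predictions.csv
# results/validation/predictions.csv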
@@ -169,8 +169,9 @@ class PASA(pl.LightningModule):
         return {f"extra_validation_loss_{dataloader_idx}": validation_loss}

     def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False):
-        names = batch["names"]
+        names = batch["name"]
         images = batch["data"]
+        labels = batch["label"]

         outputs = self(images)
         probabilities = torch.sigmoid(outputs)
@@ -180,18 +181,34 @@ class PASA(pl.LightningModule):
         if isinstance(outputs, list):
             outputs = outputs[-1]

-        return {
-            f"dataloader_{dataloader_idx}_predictions": (
-                names[0],
-                torch.flatten(probabilities),
-                torch.flatten(batch[2]),
-            )
-        }
+        results = (
+            names[0],
+            torch.flatten(probabilities),
+            torch.flatten(labels),
+        )
+        return results

-    def on_predict_epoch_end(self):
-        # Need to cache predictions in the predict step, then reorder by key
-        # Clear prediction dict
-        raise NotImplementedError
+        # {
+        #     f"dataloader_{dataloader_idx}_predictions": (
+        #         names[0],
+        #         torch.flatten(probabilities),
+        #         torch.flatten(labels),
+        #     )
+        # }
+
+    # def on_predict_epoch_end(self):
+    #     retval = defaultdict(list)
+    #     for dataloader_name, predictions in self.predictions_cache.items():
+    #         for prediction in predictions:
+    #             retval[dataloader_name]["name"].append(prediction[0])
+    #             retval[dataloader_name]["prediction"].append(prediction[1])
+    #             retval[dataloader_name]["label"].append(prediction[2])
+    #     # Need to cache predictions in the predict step, then reorder by key
+    #     # Clear prediction dict
+    #     raise NotImplementedError

     def configure_optimizers(self):
         # Dynamically instantiates the optimizer given the configs
 ...
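
The commented-out `on_predict_epoch_end` above sketches caching per-sample tuples and regrouping them by dataloader, but it is not implemented in this commit (and `predictions_cache` is never defined). Purely as a hypothetical illustration of that regrouping, using a nested mapping instead of the draft's flat `defaultdict(list)`:

from collections import defaultdict

# Hypothetical cache: {dataloader_name: [(name, prediction, label), ...]}
predictions_cache = {
    "train": [("img_a", 0.91, 1), ("img_b", 0.08, 0)],
    "validation": [("img_c", 0.55, 1)],
}

regrouped = defaultdict(lambda: defaultdict(list))
for dataloader_name, predictions in predictions_cache.items():
    for name, prediction, label in predictions:
        regrouped[dataloader_name]["name"].append(name)
        regrouped[dataloader_name]["prediction"].append(prediction)
        regrouped[dataloader_name]["label"].append(label)

print(dict(regrouped["train"]))
# {'name': ['img_a', 'img_b'], 'prediction': [0.91, 0.08], 'label': [1, 0]}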
@@ -41,7 +41,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
     cls=ResourceOption,
 )
 @click.option(
-    "--dataset",
+    "--datamodule",
     "-d",
     help="A torch.utils.data.dataset.Dataset instance implementing a dataset "
     "to be used for running prediction, possibly including all pre-processing "
@@ -77,14 +77,6 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
     required=True,
     cls=ResourceOption,
 )
-@click.option(
-    "--relevance-analysis",
-    "-r",
-    help="If set, generate relevance analysis pdfs to indicate the relative"
-    "importance of each feature",
-    is_flag=True,
-    cls=ResourceOption,
-)
 @click.option(
     "--grad-cams",
     "-g",
@@ -96,32 +88,27 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 def predict(
     output_folder,
     model,
-    dataset,
+    datamodule,
     batch_size,
     accelerator,
     weight,
-    relevance_analysis,
     grad_cams,
     **_,
 ) -> None:
     """Predicts Tuberculosis presence (probabilities) on input images."""

-    import copy
     import os
-    import shutil

     import numpy as np
-    import torch
     from matplotlib.backends.backend_pdf import PdfPages
-    from sklearn import metrics
-    from torch.utils.data import ConcatDataset, DataLoader

-    from ..data.datamodule import DataModule
     from ..engine.predictor import run
     from ..utils.plot import relevance_analysis_plot

-    dataset = dataset if isinstance(dataset, dict) else dict(test=dataset)
+    datamodule = datamodule(
+        batch_size=batch_size,
+    )

     logger.info(f"Loading checkpoint from {weight}")
     model = model.load_from_checkpoint(weight, strict=False)
@@ -141,83 +128,4 @@ def predict(
     )
     pdf.close()

-    for k, v in dataset.items():
-        if k.startswith("_"):
-            logger.info(f"Skipping dataset '{k}' (not to be evaluated)")
-            continue
-
-        logger.info(f"Running inference on '{k}' set...")
-        datamodule = DataModule(
-            v,
-            train_batch_size=batch_size,
-        )
-        predictions = run(
-            model, datamodule, k, accelerator, output_folder, grad_cams
-        )
-
-        # Relevance analysis using permutation feature importance
-        if relevance_analysis:
-            if isinstance(v, ConcatDataset) or not isinstance(
-                v._samples[0].data["data"], list
-            ):
-                logger.info(
-                    "Relevance analysis only possible with radiological signs as input. Cancelling..."
-                )
-                continue
-
-            nb_features = len(v._samples[0].data["data"])
-
-            if nb_features == 1:
-                logger.info("Relevance analysis not possible with one feature")
-            else:
-                logger.info(f"Starting relevance analysis for subset '{k}'...")
-
-                all_mse = []
-                for f in range(nb_features):
-                    v_original = copy.deepcopy(v)
-
-                    # Randomly permute feature values from all samples
-                    v.random_permute(f)
-
-                    data_loader = DataLoader(
-                        dataset=v,
-                        batch_size=batch_size,
-                        shuffle=False,
-                        pin_memory=torch.cuda.is_available(),
-                    )
-
-                    predictions_with_mean = run(
-                        model,
-                        data_loader,
-                        k,
-                        accelerator,
-                        output_folder + "_temp",
-                    )
-
-                    # Compute MSE between original and new predictions
-                    all_mse.append(
-                        metrics.mean_squared_error(
-                            np.array(predictions, dtype=object)[:, 1],
-                            np.array(predictions_with_mean, dtype=object)[:, 1],
-                        )
-                    )
-
-                    # Back to original values
-                    v = v_original
-
-                # Remove temporary folder
-                shutil.rmtree(output_folder + "_temp", ignore_errors=True)
-
-                filepath = os.path.join(output_folder, k + "_RA.pdf")
-                logger.info(f"Creating and saving plot at {filepath}...")
-                os.makedirs(os.path.dirname(filepath), exist_ok=True)
-                pdf = PdfPages(filepath)
-                pdf.savefig(
-                    relevance_analysis_plot(
-                        all_mse,
-                        title=k.capitalize() + " set relevance analysis",
-                    )
-                )
-                pdf.close()
+    _ = run(model, datamodule, accelerator, output_folder, grad_cams)
New file (relevance analysis moved out of predict.py; not currently functional):

# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

"""
import copy
import os
import shutil

import numpy as np
import torch

from matplotlib.backends.backend_pdf import PdfPages
from sklearn import metrics
from torch.utils.data import ConcatDataset, DataLoader

from ..engine.predictor import run
from ..utils.plot import relevance_analysis_plot

logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")

# Relevance analysis using permutation feature importance
if relevance_analysis:
    if isinstance(v, ConcatDataset) or not isinstance(
        v._samples[0].data["data"], list
    ):
        logger.info(
            "Relevance analysis only possible with radiological signs as input. Cancelling..."
        )
        continue

    nb_features = len(v._samples[0].data["data"])

    if nb_features == 1:
        logger.info("Relevance analysis not possible with one feature")
    else:
        logger.info(f"Starting relevance analysis for subset '{k}'...")

        all_mse = []
        for f in range(nb_features):
            v_original = copy.deepcopy(v)

            # Randomly permute feature values from all samples
            v.random_permute(f)

            data_loader = DataLoader(
                dataset=v,
                batch_size=batch_size,
                shuffle=False,
                pin_memory=torch.cuda.is_available(),
            )

            predictions_with_mean = run(
                model,
                data_loader,
                k,
                accelerator,
                output_folder + "_temp",
            )

            # Compute MSE between original and new predictions
            all_mse.append(
                metrics.mean_squared_error(
                    np.array(predictions, dtype=object)[:, 1],
                    np.array(predictions_with_mean, dtype=object)[:, 1],
                )
            )

            # Back to original values
            v = v_original

        # Remove temporary folder
        shutil.rmtree(output_folder + "_temp", ignore_errors=True)

        filepath = os.path.join(output_folder, k + "_RA.pdf")
        logger.info(f"Creating and saving plot at {filepath}...")
        os.makedirs(os.path.dirname(filepath), exist_ok=True)
        pdf = PdfPages(filepath)
        pdf.savefig(
            relevance_analysis_plot(
                all_mse,
                title=k.capitalize() + " set relevance analysis",
            )
        )
        pdf.close()
"""