diff --git a/src/ptbench/data/base_datamodule.py b/src/ptbench/data/base_datamodule.py
index 5dfbbd7e6e0b00df0d1a6208a7bb70351b1cb417..3bcd441a59a8e08c2e4d9aa33808c8d651026100 100644
--- a/src/ptbench/data/base_datamodule.py
+++ b/src/ptbench/data/base_datamodule.py
@@ -109,7 +109,7 @@ class BaseDataModule(pl.LightningDataModule):
     def predict_dataloader(self):
         loaders_dict = {}

-        loaders_dict["train_dataloader"] = self.train_dataloader()
+        loaders_dict["train_loader"] = self.train_dataloader()

         for k, v in self.val_dataloader().items():
             loaders_dict[k] = v
diff --git a/src/ptbench/engine/callbacks.py b/src/ptbench/engine/callbacks.py
index c78d52ea806310c087ecda943b269ec4129a4d11..d0ac43f98e21b8ce6803797d6a1fde38c6302660 100644
--- a/src/ptbench/engine/callbacks.py
+++ b/src/ptbench/engine/callbacks.py
@@ -1,4 +1,5 @@
 import csv
+import os
 import time

 from collections import defaultdict
@@ -141,24 +142,34 @@ class LoggingCallback(Callback):
 class PredictionsWriter(BasePredictionWriter):
     """Lightning callback to write predictions to a file."""

-    def __init__(self, logfile_name, logfile_fields, write_interval):
+    def __init__(self, output_dir, logfile_fields, write_interval):
         super().__init__(write_interval)
-        self.logfile_name = logfile_name
+        self.output_dir = output_dir
         self.logfile_fields = logfile_fields

     def write_on_epoch_end(
         self, trainer, pl_module, predictions, batch_indices
     ):
-        with open(self.logfile_name, "w") as logfile:
-            logwriter = csv.DictWriter(logfile, fieldnames=self.logfile_fields)
-            logwriter.writeheader()
-
-            for prediction in predictions:
-                logwriter.writerow(
-                    {
-                        "filename": prediction[0],
-                        "likelihood": prediction[1].numpy(),
-                        "ground_truth": prediction[2].numpy(),
-                    }
-                )
-            logfile.flush()
+        for dataloader_idx, dataloader_results in enumerate(predictions):
+            dataloader_name = list(
+                trainer.datamodule.predict_dataloader().keys()
+            )[dataloader_idx].replace("_loader", "")
+
+            logfile = os.path.join(
+                self.output_dir, dataloader_name, "predictions.csv"
+            )
+            os.makedirs(os.path.dirname(logfile), exist_ok=True)
+
+            with open(logfile, "w") as l_f:
+                logwriter = csv.DictWriter(l_f, fieldnames=self.logfile_fields)
+                logwriter.writeheader()
+
+                for prediction in dataloader_results:
+                    logwriter.writerow(
+                        {
+                            "filename": prediction[0],
+                            "likelihood": prediction[1].numpy(),
+                            "ground_truth": prediction[2].numpy(),
+                        }
+                    )
+                l_f.flush()
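The writer above relies on an implicit contract: Lightning hands ``write_on_epoch_end`` one results list per dataloader, in the insertion order of the dict returned by the datamodule's ``predict_dataloader()``, and the per-set output directory is the key minus its ``_loader`` suffix. A minimal sketch of that mapping (the loader names here are illustrative, not the full set ptbench produces):

    loaders_dict = {"train_loader": None, "validation_loader": None}
    for idx, key in enumerate(loaders_dict):
        # predictions[idx] ends up in <output_dir>/<name>/predictions.csv
        name = key.replace("_loader", "")
        print(idx, name)  # 0 train / 1 validation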
diff --git a/src/ptbench/engine/predictor.py b/src/ptbench/engine/predictor.py
index 4535b368bca54daf97fa350b91516a126eff319a..5dcbb79c9fd0a8c32f9d269f8302b888da56be84 100644
--- a/src/ptbench/engine/predictor.py
+++ b/src/ptbench/engine/predictor.py
@@ -13,7 +13,7 @@ from .callbacks import PredictionsWriter
 logger = logging.getLogger(__name__)

-def run(model, datamodule, name, accelerator, output_folder, grad_cams=False):
+def run(model, datamodule, accelerator, output_folder, grad_cams=False):
     """Runs inference on input data, outputs csv files with predictions.

     Parameters
@@ -24,10 +24,6 @@ def run(model, datamodule, name, accelerator, output_folder, grad_cams=False):
-    data_loader : py:class:`torch.torch.utils.data.DataLoader`
-        The pytorch Dataloader used to iterate over batches.
+    datamodule : py:class:`pytorch_lightning.LightningDataModule`
+        The lightning datamodule providing the prediction dataloaders.

-    name : str
-        The local name of this dataset (e.g. ``train``, or ``test``), to be
-        used when saving measures files.
-
     accelerator : str
         A string indicating the accelerator to use (e.g. "cpu" or "gpu").
         The device can also be specified (gpu:0)
@@ -44,7 +40,6 @@ def run(model, datamodule, name, accelerator, output_folder, grad_cams=False):
     all_predictions : list
         All the predictions associated with filename and ground truth.
     """
-    output_folder = os.path.join(output_folder, name)
     logger.info(f"Output folder: {output_folder}")

     os.makedirs(output_folder, exist_ok=True)
@@ -58,7 +53,6 @@ def run(model, datamodule, name, accelerator, output_folder, grad_cams=False):

     logger.info(f"Device: {devices}")

-    logfile_name = os.path.join(output_folder, "predictions.csv")
     logfile_fields = ("filename", "likelihood", "ground_truth")

     trainer = Trainer(
@@ -66,7 +60,7 @@ def run(model, datamodule, name, accelerator, output_folder, grad_cams=False):
         devices=devices,
         callbacks=[
             PredictionsWriter(
-                logfile_name=logfile_name,
+                output_dir=output_folder,
                 logfile_fields=logfile_fields,
                 write_interval="epoch",
             ),
diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py
index d9e86b7055ed1ca3a06d807ab5d02b175b2f0cef..9da9702f31436df441cdb56d9415cc06e6829623 100644
--- a/src/ptbench/models/pasa.py
+++ b/src/ptbench/models/pasa.py
@@ -169,8 +169,9 @@ class PASA(pl.LightningModule):
         return {f"extra_validation_loss_{dataloader_idx}": validation_loss}

     def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False):
-        names = batch["names"]
+        names = batch["name"]
         images = batch["data"]
+        labels = batch["label"]
         outputs = self(images)
         probabilities = torch.sigmoid(outputs)

@@ -180,18 +181,13 @@ class PASA(pl.LightningModule):
         if isinstance(outputs, list):
             outputs = outputs[-1]

-        return {
-            f"dataloader_{dataloader_idx}_predictions": (
-                names[0],
-                torch.flatten(probabilities),
-                torch.flatten(batch[2]),
-            )
-        }
-
-    def on_predict_epoch_end(self):
-        # Need to cache predictions in the predict step, then reorder by key
-        # Clear prediction dict
-        raise NotImplementedError
+        results = (
+            names[0],
+            torch.flatten(probabilities),
+            torch.flatten(labels),
+        )
+
+        return results

     def configure_optimizers(self):
         # Dynamically instantiates the optimizer given the configs
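``predict_step`` now returns one plain ``(name, probabilities, labels)`` tuple per batch and leaves aggregation to the trainer plus the ``PredictionsWriter`` callback, which indexes ``prediction[0..2]`` when writing CSV rows. A rough sketch of the shapes involved (batch size and tensor shapes are illustrative assumptions, not taken from a real run):

    import torch

    batch = {
        "name": ["patient_0001.png"],        # as keyed in predict_step above
        "data": torch.rand(1, 1, 512, 512),  # image tensor, shape assumed
        "label": torch.tensor([1]),
    }
    probabilities = torch.rand(1)            # stand-in for sigmoid(outputs)
    row = (batch["name"][0], torch.flatten(probabilities), torch.flatten(batch["label"]))
    # -> written as the filename, likelihood, ground_truth columns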
"importance of each feature", - is_flag=True, - cls=ResourceOption, -) @click.option( "--grad-cams", "-g", @@ -96,32 +88,27 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") def predict( output_folder, model, - dataset, + datamodule, batch_size, accelerator, weight, - relevance_analysis, grad_cams, **_, ) -> None: """Predicts Tuberculosis presence (probabilities) on input images.""" - import copy import os - import shutil import numpy as np - import torch from matplotlib.backends.backend_pdf import PdfPages - from sklearn import metrics - from torch.utils.data import ConcatDataset, DataLoader - from ..data.datamodule import DataModule from ..engine.predictor import run from ..utils.plot import relevance_analysis_plot - dataset = dataset if isinstance(dataset, dict) else dict(test=dataset) + datamodule = datamodule( + batch_size=batch_size, + ) logger.info(f"Loading checkpoint from {weight}") model = model.load_from_checkpoint(weight, strict=False) @@ -141,83 +128,4 @@ def predict( ) pdf.close() - for k, v in dataset.items(): - if k.startswith("_"): - logger.info(f"Skipping dataset '{k}' (not to be evaluated)") - continue - - logger.info(f"Running inference on '{k}' set...") - - datamodule = DataModule( - v, - train_batch_size=batch_size, - ) - - predictions = run( - model, datamodule, k, accelerator, output_folder, grad_cams - ) - - # Relevance analysis using permutation feature importance - if relevance_analysis: - if isinstance(v, ConcatDataset) or not isinstance( - v._samples[0].data["data"], list - ): - logger.info( - "Relevance analysis only possible with radiological signs as input. Cancelling..." - ) - continue - - nb_features = len(v._samples[0].data["data"]) - - if nb_features == 1: - logger.info("Relevance analysis not possible with one feature") - else: - logger.info(f"Starting relevance analysis for subset '{k}'...") - - all_mse = [] - for f in range(nb_features): - v_original = copy.deepcopy(v) - - # Randomly permute feature values from all samples - v.random_permute(f) - - data_loader = DataLoader( - dataset=v, - batch_size=batch_size, - shuffle=False, - pin_memory=torch.cuda.is_available(), - ) - - predictions_with_mean = run( - model, - data_loader, - k, - accelerator, - output_folder + "_temp", - ) - - # Compute MSE between original and new predictions - all_mse.append( - metrics.mean_squared_error( - np.array(predictions, dtype=object)[:, 1], - np.array(predictions_with_mean, dtype=object)[:, 1], - ) - ) - - # Back to original values - v = v_original - - # Remove temporary folder - shutil.rmtree(output_folder + "_temp", ignore_errors=True) - - filepath = os.path.join(output_folder, k + "_RA.pdf") - logger.info(f"Creating and saving plot at {filepath}...") - os.makedirs(os.path.dirname(filepath), exist_ok=True) - pdf = PdfPages(filepath) - pdf.savefig( - relevance_analysis_plot( - all_mse, - title=k.capitalize() + " set relevance analysis", - ) - ) - pdf.close() + _ = run(model, datamodule, accelerator, output_folder, grad_cams) diff --git a/src/ptbench/scripts/relevance_analysis.py b/src/ptbench/scripts/relevance_analysis.py new file mode 100644 index 0000000000000000000000000000000000000000..1771c107f1317ad181838cf408a76d2c955ad30c --- /dev/null +++ b/src/ptbench/scripts/relevance_analysis.py @@ -0,0 +1,84 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Import copy import os import shutil. 
diff --git a/src/ptbench/scripts/relevance_analysis.py b/src/ptbench/scripts/relevance_analysis.py
new file mode 100644
index 0000000000000000000000000000000000000000..1771c107f1317ad181838cf408a76d2c955ad30c
--- /dev/null
+++ b/src/ptbench/scripts/relevance_analysis.py
@@ -0,0 +1,93 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+"""Relevance analysis using permutation feature importance.
+
+Code moved out of ``predict.py`` and parked here until it is refactored
+into a standalone script: the fragment below still references names from
+its original scope in ``predict.py`` (``k``, ``v``, ``model``,
+``batch_size``, ...) and is not runnable as-is.
+
+import copy
+import os
+import shutil
+
+import numpy as np
+import torch
+
+from matplotlib.backends.backend_pdf import PdfPages
+from sklearn import metrics
+from torch.utils.data import ConcatDataset, DataLoader
+
+from ..engine.predictor import run
+from ..utils.plot import relevance_analysis_plot
+
+logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
+
+
+# Relevance analysis using permutation feature importance
+if relevance_analysis:
+    if isinstance(v, ConcatDataset) or not isinstance(
+        v._samples[0].data["data"], list
+    ):
+        logger.info(
+            "Relevance analysis only possible with radiological signs as input. Cancelling..."
+        )
+        continue
+
+    nb_features = len(v._samples[0].data["data"])
+
+    if nb_features == 1:
+        logger.info("Relevance analysis not possible with one feature")
+    else:
+        logger.info(f"Starting relevance analysis for subset '{k}'...")
+
+        all_mse = []
+        for f in range(nb_features):
+            v_original = copy.deepcopy(v)
+
+            # Randomly permute feature values from all samples
+            v.random_permute(f)
+
+            data_loader = DataLoader(
+                dataset=v,
+                batch_size=batch_size,
+                shuffle=False,
+                pin_memory=torch.cuda.is_available(),
+            )
+
+            predictions_with_mean = run(
+                model,
+                data_loader,
+                k,
+                accelerator,
+                output_folder + "_temp",
+            )
+
+            # Compute MSE between original and new predictions
+            all_mse.append(
+                metrics.mean_squared_error(
+                    np.array(predictions, dtype=object)[:, 1],
+                    np.array(predictions_with_mean, dtype=object)[:, 1],
+                )
+            )
+
+            # Back to original values
+            v = v_original
+
+        # Remove temporary folder
+        shutil.rmtree(output_folder + "_temp", ignore_errors=True)
+
+        filepath = os.path.join(output_folder, k + "_RA.pdf")
+        logger.info(f"Creating and saving plot at {filepath}...")
+        os.makedirs(os.path.dirname(filepath), exist_ok=True)
+        pdf = PdfPages(filepath)
+        pdf.savefig(
+            relevance_analysis_plot(
+                all_mse,
+                title=k.capitalize() + " set relevance analysis",
+            )
+        )
+        pdf.close()
+"""
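For orientation while the parked code awaits refactoring: its core computation is plain permutation feature importance. A self-contained sketch under stand-in names (``predict_fn`` and the toy data are placeholders, not ptbench APIs):

    import numpy as np

    def permutation_mse(predict_fn, features, seed=0):
        # For each feature column, permute it across samples and measure the
        # mean squared drift of the predictions -- the role of `all_mse` above.
        rng = np.random.default_rng(seed)
        baseline = predict_fn(features)
        all_mse = []
        for f in range(features.shape[1]):
            shuffled = features.copy()
            rng.shuffle(shuffled[:, f])  # permute feature f only
            all_mse.append(float(np.mean((predict_fn(shuffled) - baseline) ** 2)))
        return all_mse

    # toy usage: a "model" that sums its input features
    print(permutation_mse(lambda x: x.sum(axis=1), np.random.rand(16, 3)))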