From e8d70f15b5ef4c541d034c39735a7bdc946c078d Mon Sep 17 00:00:00 2001 From: dcarron <daniel.carron@idiap.ch> Date: Tue, 11 Apr 2023 11:48:44 +0200 Subject: [PATCH] Moved prediction to lightning Some DensenetRS-specific code has been removed from predictor.py, it will have to be re-added directly inside the forward of the DensenetRS model. --- src/ptbench/engine/callbacks.py | 18 ++- src/ptbench/engine/predictor.py | 237 +++----------------------------- src/ptbench/models/densenet.py | 15 +- src/ptbench/models/pasa.py | 14 ++ src/ptbench/scripts/predict.py | 13 +- 5 files changed, 55 insertions(+), 242 deletions(-) diff --git a/src/ptbench/engine/callbacks.py b/src/ptbench/engine/callbacks.py index 20d457d1..a01a7e19 100644 --- a/src/ptbench/engine/callbacks.py +++ b/src/ptbench/engine/callbacks.py @@ -91,14 +91,12 @@ class PredictionsWriter(BasePredictionWriter): logwriter = csv.DictWriter(logfile, fieldnames=self.logfile_fields) logwriter.writeheader() - # We should only get a single epoch here - for epoch in predictions: - for prediction in epoch: - logwriter.writerow( - { - "filename": prediction[0], - "likelihood": prediction[1].numpy(), - "ground_truth": prediction[2].numpy(), - } - ) + for prediction in predictions: + logwriter.writerow( + { + "filename": prediction[0], + "likelihood": prediction[1].numpy(), + "ground_truth": prediction[2].numpy(), + } + ) logfile.flush() diff --git a/src/ptbench/engine/predictor.py b/src/ptbench/engine/predictor.py index 6c4dd4af..31c3e4c6 100644 --- a/src/ptbench/engine/predictor.py +++ b/src/ptbench/engine/predictor.py @@ -2,44 +2,15 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -import csv -import datetime import logging import os -import shutil -import time -import matplotlib.pyplot as plt -import numpy -import PIL -import torch +from pytorch_lightning import Trainer -from matplotlib.gridspec import GridSpec -from matplotlib.patches import Rectangle -from torchvision import transforms -from tqdm import tqdm - -from ..utils.grad_cams import GradCAM +from .callbacks import PredictionsWriter logger = logging.getLogger(__name__) -colors = [ - [(47, 79, 79), "Cardiomegaly"], - [(255, 0, 0), "Emphysema"], - [(0, 128, 0), "Pleural effusion"], - [(0, 0, 128), "Hernia"], - [(255, 84, 0), "Infiltration"], - [(222, 184, 135), "Mass"], - [(0, 255, 0), "Nodule"], - [(0, 191, 255), "Atelectasis"], - [(0, 0, 255), "Pneumothorax"], - [(255, 0, 255), "Pleural thickening"], - [(255, 255, 0), "Pneumonia"], - [(126, 0, 255), "Fibrosis"], - [(255, 20, 147), "Edema"], - [(0, 255, 180), "Consolidation"], -] - def run(model, data_loader, name, device, output_folder, grad_cams=False): """Runs inference on input data, outputs HDF5 files with predictions. @@ -82,192 +53,18 @@ def run(model, data_loader, name, device, output_folder, grad_cams=False): logfile_name = os.path.join(output_folder, "predictions.csv") logfile_fields = ("filename", "likelihood", "ground_truth") - if os.path.exists(logfile_name): - backup = logfile_name + "~" - if os.path.exists(backup): - os.unlink(backup) - shutil.move(logfile_name, backup) - - if grad_cams: - grad_folder = os.path.join(output_folder, "cams") - logger.info(f"Grad cams folder: {grad_folder}") - os.makedirs(grad_folder, exist_ok=True) - - with open(logfile_name, "a+", newline="") as logfile: - logwriter = csv.DictWriter(logfile, fieldnames=logfile_fields) - - logwriter.writeheader() - - model.eval() # set evaluation mode - model.to(device) # set/cast parameters to device - - # Setup timers - start_total_time = time.time() - times = [] - len_samples = [] - - all_predictions = [] - - for samples in tqdm( - data_loader, - desc="batches", - leave=False, - disable=None, - ): - names = samples[0] - images = samples[1].to( - device=device, non_blocking=torch.cuda.is_available() - ) - - # Gradcams generation - allowed_models = ["DensenetRS"] - if grad_cams and model.name in allowed_models: - gcam = GradCAM(model=model) - probs, ids = gcam.forward(images) - - # To store signs overlays - cams_img = dict() - - # Top k number of radiological signs for which we generate cams - topk = 1 - - for i in range(topk): - # Keep only "positive" signs - if probs[:, [i]] > 0.5: - # Grad-CAM - b = ids[:, [i]] - gcam.backward(ids=ids[:, [i]]) - regions = gcam.generate( - target_layer="model_ft.features.denseblock4.denselayer16.conv2" - ) - - for j in range(len(images)): - current_cam = regions[j, 0].cpu().numpy() - current_cam[current_cam < 0.75] = 0.0 - current_cam[current_cam >= 0.75] = 1.0 - current_cam = PIL.Image.fromarray( - numpy.uint8(current_cam * 255), "L" - ) - cams_img[b.item()] = [ - current_cam, - round(probs[:, [i]].item(), 2), - ] - - if len(cams_img) > 0: - # Convert original image tensor into PIL Image - original_image = transforms.ToPILImage(mode="RGB")( - images[0] - ) - - for sign_id, label_prob in cams_img.items(): - label = label_prob[0] - - # Create the colored overlay for current sign - colored_sign = PIL.ImageOps.colorize( - label.convert("L"), (0, 0, 0), colors[sign_id][0] - ) - - # blend image and label together - first blend to get signs drawn with a - # slight "label_color" tone on top, then composite with original image, to - # avoid loosing brightness. - retval = PIL.Image.blend( - original_image, colored_sign, 0.5 - ) - composite_mask = PIL.ImageOps.invert(label.convert("L")) - original_image = PIL.Image.composite( - original_image, retval, composite_mask - ) - - handles = [] - labels = [] - for i, v in enumerate(colors): - # If sign present on image - if cams_img.get(i) is not None: - handles.append( - Rectangle( - (0, 0), - 1, - 1, - color=tuple(v / 255 for v in v[0]), - ) - ) - labels.append( - v[1] + " (" + str(cams_img[i][1]) + ")" - ) - - gs = GridSpec(6, 1) - fig = plt.figure(figsize=(10, 11)) - ax1 = fig.add_subplot(gs[:-1, :]) # For the plot - ax2 = fig.add_subplot(gs[-1, :]) # For the legend - - ax1.imshow(original_image) - ax1.axis("off") - ax2.legend( - handles, labels, mode="expand", ncol=3, frameon=False - ) - ax2.axis("off") - - original_filename = ( - samples[0][0].split("/")[-1].split(".")[0] - ) - cam_filename = os.path.join( - grad_folder, original_filename + "_cam.png" - ) - fig.savefig(cam_filename) - - with torch.no_grad(): - start_time = time.perf_counter() - outputs = model(images) - probabilities = torch.sigmoid(outputs) - - # necessary check for HED architecture that uses several outputs - # for loss calculation instead of just the last concatfuse block - if isinstance(outputs, list): - outputs = outputs[-1] - - # predictions = sigmoid(outputs) - - batch_time = time.perf_counter() - start_time - times.append(batch_time) - len_samples.append(len(images)) - - logdata = ( - ("filename", f"{names[0]}"), - ( - "likelihood", - f"{torch.flatten(probabilities).data.cpu().numpy()}", - ), - ( - "ground_truth", - f"{torch.flatten(samples[2]).data.cpu().numpy()}", - ), - ) - - logwriter.writerow(dict(k for k in logdata)) - logfile.flush() - tqdm.write(" | ".join([f"{k}: {v}" for (k, v) in logdata[:4]])) - - # Keep prediction for relevance analysis - all_predictions.append( - [ - names[0], - torch.flatten(probabilities).data.cpu().numpy(), - torch.flatten(samples[2]).data.cpu().numpy(), - ] - ) - - # report operational summary - total_time = datetime.timedelta( - seconds=int(time.time() - start_total_time) - ) - logger.info(f"Total time: {total_time}") - - average_batch_time = numpy.mean(times) - logger.info(f"Average batch time: {average_batch_time:g}s") - - average_image_time = numpy.sum( - numpy.array(times) * len_samples - ) / float(sum(len_samples)) - logger.info(f"Average image time: {average_image_time:g}s") - - return all_predictions + trainer = Trainer( + accelerator="auto", + devices="auto", + callbacks=[ + PredictionsWriter( + logfile_name=logfile_name, + logfile_fields=logfile_fields, + write_interval="epoch", + ), + ], + ) + + all_predictions = trainer.predict(model, data_loader) + + return all_predictions diff --git a/src/ptbench/models/densenet.py b/src/ptbench/models/densenet.py index 33476d42..17373b79 100644 --- a/src/ptbench/models/densenet.py +++ b/src/ptbench/models/densenet.py @@ -60,7 +60,6 @@ class Densenet(pl.LightningModule): tensor : :py:class:`torch.Tensor` """ - x = self.normalizer(x) x = self.model_ft(x) @@ -98,6 +97,20 @@ class Densenet(pl.LightningModule): return {"validation_loss": validation_loss} + def predict_step(self, batch, batch_idx, grad_cams=False): + names = batch[0] + images = batch[1] + + outputs = self(images) + probabilities = torch.sigmoid(outputs) + + # necessary check for HED architecture that uses several outputs + # for loss calculation instead of just the last concatfuse block + if isinstance(outputs, list): + outputs = outputs[-1] + + return names[0], torch.flatten(probabilities), torch.flatten(batch[2]) + def configure_optimizers(self): # Dynamically instantiates the optimizer given the configs optimizer = getattr(torch.optim, self.hparams.optimizer)( diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py index 4fd816e1..b31fa21d 100644 --- a/src/ptbench/models/pasa.py +++ b/src/ptbench/models/pasa.py @@ -186,6 +186,20 @@ class PASA(pl.LightningModule): return {"validation_loss": validation_loss} + def predict_step(self, batch, batch_idx, grad_cams=False): + names = batch[0] + images = batch[1] + + outputs = self(images) + probabilities = torch.sigmoid(outputs) + + # necessary check for HED architecture that uses several outputs + # for loss calculation instead of just the last concatfuse block + if isinstance(outputs, list): + outputs = outputs[-1] + + return names[0], torch.flatten(probabilities), torch.flatten(batch[2]) + def configure_optimizers(self): # Dynamically instantiates the optimizer given the configs optimizer = getattr(torch.optim, self.hparams.optimizer)( diff --git a/src/ptbench/scripts/predict.py b/src/ptbench/scripts/predict.py index 51275fc4..82939f25 100644 --- a/src/ptbench/scripts/predict.py +++ b/src/ptbench/scripts/predict.py @@ -63,6 +63,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") ) @click.option( "--device", + "-d", help='A string indicating the device to use (e.g. "cpu" or "cuda:0")', show_default=True, required=True, @@ -117,21 +118,11 @@ def predict( from torch.utils.data import ConcatDataset, DataLoader from ..engine.predictor import run - from ..utils.checkpointer import Checkpointer - from ..utils.download import download_to_tempfile from ..utils.plot import relevance_analysis_plot dataset = dataset if isinstance(dataset, dict) else dict(test=dataset) - if weight.startswith("http"): - logger.info(f"Temporarily downloading '{weight}'...") - f = download_to_tempfile(weight, progress=True) - weight_fullpath = os.path.abspath(f.name) - else: - weight_fullpath = os.path.abspath(weight) - - checkpointer = Checkpointer(model) - checkpointer.load(weight_fullpath) + model = model.load_from_checkpoint(weight) # Logistic regressor weights if model.name == "logistic_regression": -- GitLab