From e8d70f15b5ef4c541d034c39735a7bdc946c078d Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Tue, 11 Apr 2023 11:48:44 +0200
Subject: [PATCH] Moved prediction to Lightning

Some DensenetRS-specific (Grad-CAM) code has been removed from
predictor.py; it will have to be re-added directly inside the forward
pass of the DensenetRS model.
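
For reference, prediction now follows the standard Lightning pattern:
each model implements predict_step(), Trainer.predict() collects the
per-batch return values, and a BasePredictionWriter callback serializes
them at the end of the predict epoch. A minimal sketch of how the pieces
introduced below fit together (illustrative only; names mirror the code
in this patch):

    from pytorch_lightning import Trainer

    from .callbacks import PredictionsWriter

    trainer = Trainer(
        accelerator="auto",
        devices="auto",
        callbacks=[
            PredictionsWriter(
                logfile_name="predictions.csv",
                logfile_fields=("filename", "likelihood", "ground_truth"),
                write_interval="epoch",
            ),
        ],
    )

    # each entry is a (filename, likelihood, ground_truth) triple, as
    # returned by the models' predict_step()
    all_predictions = trainer.predict(model, data_loader)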
---
 src/ptbench/engine/callbacks.py |  18 ++-
 src/ptbench/engine/predictor.py | 237 +++-----------------------------
 src/ptbench/models/densenet.py  |  15 +-
 src/ptbench/models/pasa.py      |  14 ++
 src/ptbench/scripts/predict.py  |  13 +-
 5 files changed, 55 insertions(+), 242 deletions(-)

diff --git a/src/ptbench/engine/callbacks.py b/src/ptbench/engine/callbacks.py
index 20d457d1..a01a7e19 100644
--- a/src/ptbench/engine/callbacks.py
+++ b/src/ptbench/engine/callbacks.py
@@ -91,14 +91,12 @@ class PredictionsWriter(BasePredictionWriter):
             logwriter = csv.DictWriter(logfile, fieldnames=self.logfile_fields)
             logwriter.writeheader()
 
-            # We should only get a single epoch here
-            for epoch in predictions:
-                for prediction in epoch:
-                    logwriter.writerow(
-                        {
-                            "filename": prediction[0],
-                            "likelihood": prediction[1].numpy(),
-                            "ground_truth": prediction[2].numpy(),
-                        }
-                    )
+            for prediction in predictions:
+                logwriter.writerow(
+                    {
+                        "filename": prediction[0],
+                        "likelihood": prediction[1].numpy(),
+                        "ground_truth": prediction[2].numpy(),
+                    }
+                )
             logfile.flush()
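
For context, PredictionsWriter (only partially visible in the hunk
above) is a pytorch_lightning BasePredictionWriter subclass; with
write_interval="epoch" it receives all accumulated predict_step outputs
at once. A minimal sketch of the full class, assuming a single predict
dataloader so that predictions is a flat list of (filename, likelihood,
ground_truth) triples (constructor arguments mirror the call site in
predictor.py below):

    import csv

    from pytorch_lightning.callbacks import BasePredictionWriter


    class PredictionsWriter(BasePredictionWriter):
        """Saves predictions to a CSV file at the end of the predict epoch."""

        def __init__(self, logfile_name, logfile_fields, write_interval):
            super().__init__(write_interval)
            self.logfile_name = logfile_name
            self.logfile_fields = logfile_fields

        def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices):
            # file mode is an assumption, not visible in the hunk above
            with open(self.logfile_name, "w", newline="") as logfile:
                logwriter = csv.DictWriter(logfile, fieldnames=self.logfile_fields)
                logwriter.writeheader()

                for prediction in predictions:
                    logwriter.writerow(
                        {
                            "filename": prediction[0],
                            "likelihood": prediction[1].numpy(),
                            "ground_truth": prediction[2].numpy(),
                        }
                    )
                logfile.flush()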
diff --git a/src/ptbench/engine/predictor.py b/src/ptbench/engine/predictor.py
index 6c4dd4af..31c3e4c6 100644
--- a/src/ptbench/engine/predictor.py
+++ b/src/ptbench/engine/predictor.py
@@ -2,44 +2,15 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
-import csv
-import datetime
 import logging
 import os
-import shutil
-import time
 
-import matplotlib.pyplot as plt
-import numpy
-import PIL
-import torch
+from pytorch_lightning import Trainer
 
-from matplotlib.gridspec import GridSpec
-from matplotlib.patches import Rectangle
-from torchvision import transforms
-from tqdm import tqdm
-
-from ..utils.grad_cams import GradCAM
+from .callbacks import PredictionsWriter
 
 logger = logging.getLogger(__name__)
 
-colors = [
-    [(47, 79, 79), "Cardiomegaly"],
-    [(255, 0, 0), "Emphysema"],
-    [(0, 128, 0), "Pleural effusion"],
-    [(0, 0, 128), "Hernia"],
-    [(255, 84, 0), "Infiltration"],
-    [(222, 184, 135), "Mass"],
-    [(0, 255, 0), "Nodule"],
-    [(0, 191, 255), "Atelectasis"],
-    [(0, 0, 255), "Pneumothorax"],
-    [(255, 0, 255), "Pleural thickening"],
-    [(255, 255, 0), "Pneumonia"],
-    [(126, 0, 255), "Fibrosis"],
-    [(255, 20, 147), "Edema"],
-    [(0, 255, 180), "Consolidation"],
-]
-
 
 def run(model, data_loader, name, device, output_folder, grad_cams=False):
     """Runs inference on input data, outputs HDF5 files with predictions.
@@ -82,192 +53,18 @@ def run(model, data_loader, name, device, output_folder, grad_cams=False):
     logfile_name = os.path.join(output_folder, "predictions.csv")
     logfile_fields = ("filename", "likelihood", "ground_truth")
 
-    if os.path.exists(logfile_name):
-        backup = logfile_name + "~"
-        if os.path.exists(backup):
-            os.unlink(backup)
-        shutil.move(logfile_name, backup)
-
-    if grad_cams:
-        grad_folder = os.path.join(output_folder, "cams")
-        logger.info(f"Grad cams folder: {grad_folder}")
-        os.makedirs(grad_folder, exist_ok=True)
-
-    with open(logfile_name, "a+", newline="") as logfile:
-        logwriter = csv.DictWriter(logfile, fieldnames=logfile_fields)
-
-        logwriter.writeheader()
-
-        model.eval()  # set evaluation mode
-        model.to(device)  # set/cast parameters to device
-
-        # Setup timers
-        start_total_time = time.time()
-        times = []
-        len_samples = []
-
-        all_predictions = []
-
-        for samples in tqdm(
-            data_loader,
-            desc="batches",
-            leave=False,
-            disable=None,
-        ):
-            names = samples[0]
-            images = samples[1].to(
-                device=device, non_blocking=torch.cuda.is_available()
-            )
-
-            # Gradcams generation
-            allowed_models = ["DensenetRS"]
-            if grad_cams and model.name in allowed_models:
-                gcam = GradCAM(model=model)
-                probs, ids = gcam.forward(images)
-
-                # To store signs overlays
-                cams_img = dict()
-
-                # Top k number of radiological signs for which we generate cams
-                topk = 1
-
-                for i in range(topk):
-                    # Keep only "positive" signs
-                    if probs[:, [i]] > 0.5:
-                        # Grad-CAM
-                        b = ids[:, [i]]
-                        gcam.backward(ids=ids[:, [i]])
-                        regions = gcam.generate(
-                            target_layer="model_ft.features.denseblock4.denselayer16.conv2"
-                        )
-
-                        for j in range(len(images)):
-                            current_cam = regions[j, 0].cpu().numpy()
-                            current_cam[current_cam < 0.75] = 0.0
-                            current_cam[current_cam >= 0.75] = 1.0
-                            current_cam = PIL.Image.fromarray(
-                                numpy.uint8(current_cam * 255), "L"
-                            )
-                            cams_img[b.item()] = [
-                                current_cam,
-                                round(probs[:, [i]].item(), 2),
-                            ]
-
-                if len(cams_img) > 0:
-                    # Convert original image tensor into PIL Image
-                    original_image = transforms.ToPILImage(mode="RGB")(
-                        images[0]
-                    )
-
-                    for sign_id, label_prob in cams_img.items():
-                        label = label_prob[0]
-
-                        # Create the colored overlay for current sign
-                        colored_sign = PIL.ImageOps.colorize(
-                            label.convert("L"), (0, 0, 0), colors[sign_id][0]
-                        )
-
-                        # blend image and label together - first blend to get signs drawn with a
-                        # slight "label_color" tone on top, then composite with original image, to
-                        # avoid loosing brightness.
-                        retval = PIL.Image.blend(
-                            original_image, colored_sign, 0.5
-                        )
-                        composite_mask = PIL.ImageOps.invert(label.convert("L"))
-                        original_image = PIL.Image.composite(
-                            original_image, retval, composite_mask
-                        )
-
-                    handles = []
-                    labels = []
-                    for i, v in enumerate(colors):
-                        # If sign present on image
-                        if cams_img.get(i) is not None:
-                            handles.append(
-                                Rectangle(
-                                    (0, 0),
-                                    1,
-                                    1,
-                                    color=tuple(v / 255 for v in v[0]),
-                                )
-                            )
-                            labels.append(
-                                v[1] + " (" + str(cams_img[i][1]) + ")"
-                            )
-
-                    gs = GridSpec(6, 1)
-                    fig = plt.figure(figsize=(10, 11))
-                    ax1 = fig.add_subplot(gs[:-1, :])  # For the plot
-                    ax2 = fig.add_subplot(gs[-1, :])  # For the legend
-
-                    ax1.imshow(original_image)
-                    ax1.axis("off")
-                    ax2.legend(
-                        handles, labels, mode="expand", ncol=3, frameon=False
-                    )
-                    ax2.axis("off")
-
-                    original_filename = (
-                        samples[0][0].split("/")[-1].split(".")[0]
-                    )
-                    cam_filename = os.path.join(
-                        grad_folder, original_filename + "_cam.png"
-                    )
-                    fig.savefig(cam_filename)
-
-            with torch.no_grad():
-                start_time = time.perf_counter()
-                outputs = model(images)
-                probabilities = torch.sigmoid(outputs)
-
-                # necessary check for HED architecture that uses several outputs
-                # for loss calculation instead of just the last concatfuse block
-                if isinstance(outputs, list):
-                    outputs = outputs[-1]
-
-                # predictions = sigmoid(outputs)
-
-                batch_time = time.perf_counter() - start_time
-                times.append(batch_time)
-                len_samples.append(len(images))
-
-                logdata = (
-                    ("filename", f"{names[0]}"),
-                    (
-                        "likelihood",
-                        f"{torch.flatten(probabilities).data.cpu().numpy()}",
-                    ),
-                    (
-                        "ground_truth",
-                        f"{torch.flatten(samples[2]).data.cpu().numpy()}",
-                    ),
-                )
-
-                logwriter.writerow(dict(k for k in logdata))
-                logfile.flush()
-                tqdm.write(" | ".join([f"{k}: {v}" for (k, v) in logdata[:4]]))
-
-                # Keep prediction for relevance analysis
-                all_predictions.append(
-                    [
-                        names[0],
-                        torch.flatten(probabilities).data.cpu().numpy(),
-                        torch.flatten(samples[2]).data.cpu().numpy(),
-                    ]
-                )
-
-        # report operational summary
-        total_time = datetime.timedelta(
-            seconds=int(time.time() - start_total_time)
-        )
-        logger.info(f"Total time: {total_time}")
-
-        average_batch_time = numpy.mean(times)
-        logger.info(f"Average batch time: {average_batch_time:g}s")
-
-        average_image_time = numpy.sum(
-            numpy.array(times) * len_samples
-        ) / float(sum(len_samples))
-        logger.info(f"Average image time: {average_image_time:g}s")
-
-        return all_predictions
+    trainer = Trainer(
+        accelerator="auto",
+        devices="auto",
+        callbacks=[
+            PredictionsWriter(
+                logfile_name=logfile_name,
+                logfile_fields=logfile_fields,
+                write_interval="epoch",
+            ),
+        ],
+    )
+
+    all_predictions = trainer.predict(model, data_loader)
+
+    return all_predictions
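
Note that with this change the name, device and grad_cams arguments of
run() are effectively unused: device placement is delegated to the
Trainer through accelerator="auto"/devices="auto", and Grad-CAM
generation is gone until it is re-added inside the DensenetRS model (see
the commit message). A hypothetical call site, assuming a standard torch
DataLoader over a test split:

    from torch.utils.data import DataLoader

    loader = DataLoader(test_set, batch_size=1, shuffle=False)

    # "test" and "cpu" are accepted but currently ignored by the new
    # implementation
    predictions = run(model, loader, "test", "cpu", "results")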
diff --git a/src/ptbench/models/densenet.py b/src/ptbench/models/densenet.py
index 33476d42..17373b79 100644
--- a/src/ptbench/models/densenet.py
+++ b/src/ptbench/models/densenet.py
@@ -60,7 +60,6 @@ class Densenet(pl.LightningModule):
         tensor : :py:class:`torch.Tensor`
 
         """
-
         x = self.normalizer(x)
 
         x = self.model_ft(x)
@@ -98,6 +97,20 @@ class Densenet(pl.LightningModule):
 
         return {"validation_loss": validation_loss}
 
+    def predict_step(self, batch, batch_idx, grad_cams=False):
+        names = batch[0]
+        images = batch[1]
+
+        outputs = self(images)
+        # necessary check for the HED architecture, which returns several
+        # outputs for loss calculation instead of just the last concatfuse
+        # block; it must run before the sigmoid, which cannot take a list
+        if isinstance(outputs, list):
+            outputs = outputs[-1]
+        probabilities = torch.sigmoid(outputs)
+
+        return names[0], torch.flatten(probabilities), torch.flatten(batch[2])
+
     def configure_optimizers(self):
         # Dynamically instantiates the optimizer given the configs
         optimizer = getattr(torch.optim, self.hparams.optimizer)(
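
Lightning calls predict_step() once per batch; the extra grad_cams
keyword keeps a hook for the planned Grad-CAM re-integration but is
never set by the Trainer itself. Note also that returning names[0] and
flattening the whole probability tensor implicitly assumes a predict
batch size of 1, just as the old predictor loop did. A batch-size-
agnostic variant, shown only as a possible follow-up (the epoch writer
would then receive one list of triples per batch):

    def predict_step(self, batch, batch_idx, grad_cams=False):
        names, images, labels = batch[0], batch[1], batch[2]

        outputs = self(images)
        if isinstance(outputs, list):  # HED-style multi-output models
            outputs = outputs[-1]
        probabilities = torch.sigmoid(outputs)

        # one (filename, likelihood, ground_truth) triple per sample
        return [
            (name, prob.flatten(), label.flatten())
            for name, prob, label in zip(names, probabilities, labels)
        ]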
diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py
index 4fd816e1..b31fa21d 100644
--- a/src/ptbench/models/pasa.py
+++ b/src/ptbench/models/pasa.py
@@ -186,6 +186,20 @@ class PASA(pl.LightningModule):
 
         return {"validation_loss": validation_loss}
 
+    def predict_step(self, batch, batch_idx, grad_cams=False):
+        names = batch[0]
+        images = batch[1]
+
+        outputs = self(images)
+        # necessary check for the HED architecture, which returns several
+        # outputs for loss calculation instead of just the last concatfuse
+        # block; it must run before the sigmoid, which cannot take a list
+        if isinstance(outputs, list):
+            outputs = outputs[-1]
+        probabilities = torch.sigmoid(outputs)
+
+        return names[0], torch.flatten(probabilities), torch.flatten(batch[2])
+
     def configure_optimizers(self):
         # Dynamically instantiates the optimizer given the configs
         optimizer = getattr(torch.optim, self.hparams.optimizer)(
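
predict_step() is now duplicated verbatim between Densenet and PASA. If
more models gain prediction support, a shared mixin would keep the
copies from drifting apart; a sketch under the hypothetical name
PredictMixin (not part of this patch):

    import torch
    import pytorch_lightning as pl


    class PredictMixin:
        """Shared prediction step for ptbench LightningModules."""

        def predict_step(self, batch, batch_idx, grad_cams=False):
            names = batch[0]
            images = batch[1]

            outputs = self(images)
            # HED-style models return several outputs; keep the last one
            if isinstance(outputs, list):
                outputs = outputs[-1]
            probabilities = torch.sigmoid(outputs)

            return names[0], torch.flatten(probabilities), torch.flatten(batch[2])


    class PASA(PredictMixin, pl.LightningModule):
        ...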
diff --git a/src/ptbench/scripts/predict.py b/src/ptbench/scripts/predict.py
index 51275fc4..82939f25 100644
--- a/src/ptbench/scripts/predict.py
+++ b/src/ptbench/scripts/predict.py
@@ -63,6 +63,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 )
 @click.option(
     "--device",
+    "-d",
     help='A string indicating the device to use (e.g. "cpu" or "cuda:0")',
     show_default=True,
     required=True,
@@ -117,21 +118,11 @@ def predict(
     from torch.utils.data import ConcatDataset, DataLoader
 
     from ..engine.predictor import run
-    from ..utils.checkpointer import Checkpointer
-    from ..utils.download import download_to_tempfile
     from ..utils.plot import relevance_analysis_plot
 
     dataset = dataset if isinstance(dataset, dict) else dict(test=dataset)
 
-    if weight.startswith("http"):
-        logger.info(f"Temporarily downloading '{weight}'...")
-        f = download_to_tempfile(weight, progress=True)
-        weight_fullpath = os.path.abspath(f.name)
-    else:
-        weight_fullpath = os.path.abspath(weight)
-
-    checkpointer = Checkpointer(model)
-    checkpointer.load(weight_fullpath)
+    model = model.load_from_checkpoint(weight)
 
     # Logistic regressor weights
     if model.name == "logistic_regression":
-- 
GitLab