From 6bb7ac555756a8f36bc2db88b679fa31c6423686 Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Tue, 11 Apr 2023 14:32:09 +0200
Subject: [PATCH] Moved signs_to_tb model to lightning

---
 src/ptbench/configs/models/signs_to_tb.py | 19 +++---
 src/ptbench/models/signs_to_tb.py         | 80 +++++++++++++++++++----
 2 files changed, 79 insertions(+), 20 deletions(-)

diff --git a/src/ptbench/configs/models/signs_to_tb.py b/src/ptbench/configs/models/signs_to_tb.py
index 3bd552da..1ce89b12 100644
--- a/src/ptbench/configs/models/signs_to_tb.py
+++ b/src/ptbench/configs/models/signs_to_tb.py
@@ -8,19 +8,22 @@ Simple feedforward network taking radiological signs in output and
 predicting tuberculosis presence in output.
 """
 
+from torch import empty
 from torch.nn import BCEWithLogitsLoss
-from torch.optim import Adam
 
-from ...models.signs_to_tb import build_signs_to_tb
+from ...models.signs_to_tb import SignsToTB
 
 # config
-lr = 1e-2
-
-# model
-model = build_signs_to_tb(14, 10)
+optimizer_configs = {"lr": 1e-2}
 
 # optimizer
-optimizer = Adam(model.parameters(), lr=lr)
+optimizer = "Adam"
 
 # criterion
-criterion = BCEWithLogitsLoss()
+criterion = BCEWithLogitsLoss(pos_weight=empty(1))
+criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1))
+
+# model
+model = SignsToTB(
+    criterion, criterion_valid, optimizer, optimizer_configs, 14, 10
+)
diff --git a/src/ptbench/models/signs_to_tb.py b/src/ptbench/models/signs_to_tb.py
index f3b3d5ea..653b590a 100644
--- a/src/ptbench/models/signs_to_tb.py
+++ b/src/ptbench/models/signs_to_tb.py
@@ -2,15 +2,31 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
+import pytorch_lightning as pl
 import torch
-import torch.nn as nn
 
 
-class SignsToTB(nn.Module):
+class SignsToTB(pl.LightningModule):
     """Radiological signs to Tuberculosis module."""
 
-    def __init__(self, input_size, hidden_size):
+    def __init__(
+        self,
+        criterion,
+        criterion_valid,
+        optimizer,
+        optimizer_configs,
+        input_size,
+        hidden_size,
+    ):
         super().__init__()
+
+        self.save_hyperparameters()
+
+        self.name = "signs_to_tb"
+
+        self.criterion = criterion
+        self.criterion_valid = criterion_valid
+
         self.input_size = input_size
         self.hidden_size = hidden_size
         self.fc1 = torch.nn.Linear(self.input_size, self.hidden_size)
@@ -39,15 +55,55 @@ class SignsToTB(nn.Module):
 
         return output
 
+    def training_step(self, batch, batch_idx):
+        images = batch[1]
+        labels = batch[2]
+
+        # Reshape 1D labels to (batch, 1) so the loss function
+        # accepts both single-label and multi-label targets
+        if labels.ndim == 1:
+            labels = torch.reshape(labels, (labels.shape[0], 1))
+
+        # Forward pass on the network
+        outputs = self(images)
+
+        training_loss = self.criterion(outputs, labels.double())
+
+        return {"loss": training_loss}
+
+    def validation_step(self, batch, batch_idx):
+        images = batch[1]
+        labels = batch[2]
+
+        # Reshape 1D labels to (batch, 1) so the loss function
+        # accepts both single-label and multi-label targets
+        if labels.ndim == 1:
+            labels = torch.reshape(labels, (labels.shape[0], 1))
+
+        # Forward pass on the network
+        outputs = self(images)
+        validation_loss = self.criterion_valid(outputs, labels.double())
+
+        return {"validation_loss": validation_loss}
+
+    def predict_step(self, batch, batch_idx, grad_cams=False):
+        names = batch[0]
+        images = batch[1]
+
+        outputs = self(images)
+        probabilities = torch.sigmoid(outputs)
+
+        # The HED architecture returns a list of outputs (one per side
+        # branch); in that case keep only the last (concat-fuse) output
+        if isinstance(outputs, list):
+            outputs = outputs[-1]
 
-def build_signs_to_tb(input_size, hidden_size):
-    """Build SignsToTB shallow model.
+        return names[0], torch.flatten(probabilities), torch.flatten(batch[2])
 
-    Returns
-    -------
+    def configure_optimizers(self):
+        # Dynamically instantiates the optimizer given the configs
+        optimizer = getattr(torch.optim, self.hparams.optimizer)(
+            self.parameters(), **self.hparams.optimizer_configs
+        )
 
-    module : :py:class:`torch.nn.Module`
-    """
-    model = SignsToTB(input_size, hidden_size)
-    model.name = "signs_to_tb"
-    return model
+        return optimizer
-- 
GitLab