From 6290c1bd893c1832ecb2c0e50950b2ece27fb4e6 Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Tue, 11 Apr 2023 08:18:18 +0200
Subject: [PATCH] Move densenet to lightning

---
 src/ptbench/configs/models/densenet.py | 19 +++---
 src/ptbench/models/densenet.py         | 82 +++++++++++++++++++-------
 2 files changed, 73 insertions(+), 28 deletions(-)

diff --git a/src/ptbench/configs/models/densenet.py b/src/ptbench/configs/models/densenet.py
index 2017786a..67594908 100644
--- a/src/ptbench/configs/models/densenet.py
+++ b/src/ptbench/configs/models/densenet.py
@@ -4,19 +4,22 @@
 
 """DenseNet."""
 
+from torch import empty
 from torch.nn import BCEWithLogitsLoss
-from torch.optim import Adam
 
-from ...models.densenet import build_densenet
+from ...models.densenet import Densenet
 
 # config
-lr = 0.0001
-
-# model
-model = build_densenet(pretrained=False)
+optimizer_configs = {"lr": 0.0001}
 
 # optimizer
-optimizer = Adam(model.parameters(), lr=lr)
+optimizer = "Adam"
 
 # criterion
-criterion = BCEWithLogitsLoss()
+criterion = BCEWithLogitsLoss(pos_weight=empty(1))
+criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1))
+
+# model
+model = Densenet(
+    criterion, criterion_valid, optimizer, optimizer_configs, pretrained=False
+)
diff --git a/src/ptbench/models/densenet.py b/src/ptbench/models/densenet.py
index 7a98acac..33476d42 100644
--- a/src/ptbench/models/densenet.py
+++ b/src/ptbench/models/densenet.py
@@ -2,23 +2,40 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
-from collections import OrderedDict
-
+import pytorch_lightning as pl
+import torch
 import torch.nn as nn
 import torchvision.models as models
 
 from .normalizer import TorchVisionNormalizer
 
 
-class Densenet(nn.Module):
+class Densenet(pl.LightningModule):
     """Densenet module.
 
     Note: only usable with a normalized dataset
     """
 
-    def __init__(self, pretrained=False):
+    def __init__(
+        self,
+        criterion,
+        criterion_valid,
+        optimizer,
+        optimizer_params,
+        pretrained=False,
+        nb_channels=3,
+    ):
         super().__init__()
 
+        self.save_hyperparameters()
+
+        self.name = "Densenet"
+
+        self.criterion = criterion
+        self.criterion_valid = criterion_valid
+
+        self.normalizer = TorchVisionNormalizer(nb_channels=nb_channels)
+
         # Load pretrained model
         weights = None if not pretrained else models.DenseNet121_Weights.DEFAULT
         self.model_ft = models.densenet121(weights=weights)
@@ -43,23 +60,48 @@ class Densenet(nn.Module):
         tensor : :py:class:`torch.Tensor`
 
         """
-        return self.model_ft(x)
 
+        x = self.normalizer(x)
 
-def build_densenet(pretrained=False, nb_channels=3):
-    """Build Densenet CNN.
+        x = self.model_ft(x)
 
-    Returns
-    -------
+        return x
 
-    module : :py:class:`torch.nn.Module`
-    """
-    model = Densenet(pretrained=pretrained)
-    model = [
-        ("normalizer", TorchVisionNormalizer(nb_channels=nb_channels)),
-        ("model", model),
-    ]
-    model = nn.Sequential(OrderedDict(model))
-
-    model.name = "Densenet"
-    return model
+    def training_step(self, batch, batch_idx):
+        images = batch[1]
+        labels = batch[2]
+
+        # Increase label dimension if too low
+        # Allows single and multiclass usage
+        if labels.ndim == 1:
+            labels = torch.reshape(labels, (labels.shape[0], 1))
+
+        # Forward pass on the network
+        outputs = self(images)
+
+        training_loss = self.criterion(outputs, labels.double())
+
+        return {"loss": training_loss}
+
+    def validation_step(self, batch, batch_idx):
+        images = batch[1]
+        labels = batch[2]
+
+        # Increase label dimension if too low
+        # Allows single and multiclass usage
+        if labels.ndim == 1:
+            labels = torch.reshape(labels, (labels.shape[0], 1))
+
+        # data forwarding on the existing network
+        outputs = self(images)
+        validation_loss = self.criterion_valid(outputs, labels.double())
+
+        return {"validation_loss": validation_loss}
+
+    def configure_optimizers(self):
+        # Dynamically instantiates the optimizer given the configs
+        optimizer = getattr(torch.optim, self.hparams.optimizer)(
+            self.parameters(), **self.hparams.optimizer_params
+        )
+
+        return optimizer
-- 
GitLab