diff --git a/src/ptbench/configs/models_datasets/densenet_rs.py b/src/ptbench/configs/models_datasets/densenet_rs.py
index 57fc7b78674779d50fe16002dc154b0dbe13e33a..714404bf173c8868e4c67e49f257a1b221e540d7 100644
--- a/src/ptbench/configs/models_datasets/densenet_rs.py
+++ b/src/ptbench/configs/models_datasets/densenet_rs.py
@@ -7,10 +7,10 @@
 A Densenet121 model for radiological extraction
 """
 
+from torch import empty
 from torch.nn import BCEWithLogitsLoss
-from torch.optim import Adam
 
-from ...models.densenet_rs import build_densenetrs
+from ...models.densenet_rs import DensenetRS
 
 # Import the default protocol if none is available
 if "dataset" not in locals():
@@ -19,16 +19,14 @@ if "dataset" not in locals():
     dataset = default.dataset
 
 # config
-lr = 1e-4
-
-# model
-model = build_densenetrs()
+optimizer_configs = {"lr": 1e-4}
 
 # optimizer
-optimizer = Adam(
-    filter(lambda p: p.requires_grad, model.model.model_ft.parameters()), lr=lr
-)
+optimizer = "Adam"
 
 # criterion
-criterion = BCEWithLogitsLoss()
-criterion_valid = BCEWithLogitsLoss()
+criterion = BCEWithLogitsLoss(pos_weight=empty(1))
+criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1))
+
+# model
+model = DensenetRS(criterion, criterion_valid, optimizer, optimizer_configs)
diff --git a/src/ptbench/models/densenet_rs.py b/src/ptbench/models/densenet_rs.py
index c4448fbca8d97a60ccc041cc847bd79a0fd58056..997516a02bcdb5f2b7fdbe04e10fd48077d51092 100644
--- a/src/ptbench/models/densenet_rs.py
+++ b/src/ptbench/models/densenet_rs.py
@@ -2,20 +2,35 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
-from collections import OrderedDict
-
+import pytorch_lightning as pl
+import torch
 import torch.nn as nn
 import torchvision.models as models
 
 from .normalizer import TorchVisionNormalizer
 
 
-class DensenetRS(nn.Module):
+class DensenetRS(pl.LightningModule):
     """Densenet121 module for radiological extraction."""
 
-    def __init__(self):
+    def __init__(
+        self,
+        criterion,
+        criterion_valid,
+        optimizer,
+        optimizer_configs,
+    ):
         super().__init__()
 
+        self.save_hyperparameters(ignore=["criterion", "criterion_valid"])
+
+        self.name = "DensenetRS"
+
+        self.criterion = criterion
+        self.criterion_valid = criterion_valid
+
+        self.normalizer = TorchVisionNormalizer()
+
         # Load pretrained model
         self.model_ft = models.densenet121(
             weights=models.DenseNet121_Weights.DEFAULT
@@ -40,20 +55,61 @@ class DensenetRS(nn.Module):
         tensor : :py:class:`torch.Tensor`
 
         """
-        return self.model_ft(x)
 
+        x = self.normalizer(x)
+        x = self.model_ft(x)
+        return x
+
+    def training_step(self, batch, batch_idx):
+        images = batch[1]
+        labels = batch[2]
+
+        # Reshape 1-D label vectors to (batch, 1) so single-label and
+        # multi-label targets share the same loss-computation path.
+        if labels.ndim == 1:
+            labels = torch.reshape(labels, (labels.shape[0], 1))
+
+        # Forward pass on the network
+        outputs = self(images)
+
+        training_loss = self.criterion(outputs, labels.double())
 
-def build_densenetrs():
-    """Build DensenetRS CNN.
+        return {"loss": training_loss}
 
-    Returns
-    -------
+    def validation_step(self, batch, batch_idx):
+        images = batch[1]
+        labels = batch[2]
 
-    module : :py:class:`torch.nn.Module`
-    """
-    model = DensenetRS()
-    model = [("normalizer", TorchVisionNormalizer()), ("model", model)]
-    model = nn.Sequential(OrderedDict(model))
+        # Reshape 1-D label vectors to (batch, 1) so single-label and
+        # multi-label targets share the same loss-computation path.
+        if labels.ndim == 1:
+            labels = torch.reshape(labels, (labels.shape[0], 1))
+
+        # data forwarding on the existing network
+        outputs = self(images)
+        validation_loss = self.criterion_valid(outputs, labels.double())
+
+        return {"validation_loss": validation_loss}
+
+    def predict_step(self, batch, batch_idx, grad_cams=False):
+        names = batch[0]
+        images = batch[1]
+
+        outputs = self(images)
+        probabilities = torch.sigmoid(outputs)
+
+        # NOTE(review): leftover from HED-style multi-output models; sigmoid was
+        # already applied above, so this check never alters `probabilities` — confirm/remove.
+        if isinstance(outputs, list):
+            outputs = outputs[-1]
+
+        return names[0], torch.flatten(probabilities), torch.flatten(batch[2])
+
+    def configure_optimizers(self):
+        # Instantiate the torch.optim class named by self.hparams.optimizer
+        optimizer = getattr(torch.optim, self.hparams.optimizer)(
+            filter(lambda p: p.requires_grad, self.model_ft.parameters()),
+            **self.hparams.optimizer_configs,
+        )
 
-    model.name = "DensenetRS"
-    return model
+        return optimizer