diff --git a/src/ptbench/configs/models/logistic_regression.py b/src/ptbench/configs/models/logistic_regression.py
index b93935b471b3f34a371c81b7659ae12309bee05b..145dddd7a3a9c6e0b0574f3c6358f6f9f51056a0 100644
--- a/src/ptbench/configs/models/logistic_regression.py
+++ b/src/ptbench/configs/models/logistic_regression.py
@@ -7,20 +7,23 @@
 Simple feedforward network taking radiological signs in output and
 predicting tuberculosis presence in output.
 """
-
+from torch import empty
 from torch.nn import BCEWithLogitsLoss
-from torch.optim import Adam
 
-from ...models.logistic_regression import build_logistic_regression
+from ...models.logistic_regression import LogisticRegression
 
 # config
-lr = 1e-2
-
-# model
-model = build_logistic_regression(14)
+optimizer_configs = {"lr": 1e-2}
+input_size = 14
 
 # optimizer
-optimizer = Adam(model.parameters(), lr=lr)
+optimizer = "Adam"
 
 # criterion
-criterion = BCEWithLogitsLoss()
+criterion = BCEWithLogitsLoss(pos_weight=empty(1))
+criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1))
+
+# model
+model = LogisticRegression(
+    criterion, criterion_valid, optimizer, optimizer_configs, input_size
+)
diff --git a/src/ptbench/models/logistic_regression.py b/src/ptbench/models/logistic_regression.py
index 7e7818c71d8ebdb636a9b863965974cb71f91fba..684155b41e2e92a4903ae7ffbbe67dce74d93d70 100644
--- a/src/ptbench/models/logistic_regression.py
+++ b/src/ptbench/models/logistic_regression.py
@@ -2,16 +2,32 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
+import pytorch_lightning as pl
 import torch
 import torch.nn as nn
 
 
-class LogisticRegression(nn.Module):
+class LogisticRegression(pl.LightningModule):
     """Radiological signs to Tuberculosis module."""
 
-    def __init__(self, input_size):
+    def __init__(
+        self,
+        criterion,
+        criterion_valid,
+        optimizer,
+        optimizer_configs,
+        input_size,
+    ):
         super().__init__()
-        self.linear = torch.nn.Linear(input_size, 1)
+
+        self.save_hyperparameters()
+
+        self.criterion = criterion
+        self.criterion_valid = criterion_valid
+
+        self.name = "logistic_regression"
+
+        self.linear = nn.Linear(input_size, 1)
 
     def forward(self, x):
         """
@@ -32,15 +48,54 @@ class LogisticRegression(nn.Module):
 
         return output
 
+    def training_step(self, batch, batch_idx):
+        images = batch[1]
+        labels = batch[2]
+
+        # Increase label dimension if too low
+        # Allows single and multiclass usage
+        if labels.ndim == 1:
+            labels = torch.reshape(labels, (labels.shape[0], 1))
+
+        # Forward pass on the network
+        outputs = self(images)
+
+        training_loss = self.criterion(outputs, labels.double())
+
+        return {"loss": training_loss}
+
+    def validation_step(self, batch, batch_idx):
+        images = batch[1]
+        labels = batch[2]
+
+        # Increase label dimension if too low
+        # Allows single and multiclass usage
+        if labels.ndim == 1:
+            labels = torch.reshape(labels, (labels.shape[0], 1))
+
+        # data forwarding on the existing network
+        outputs = self(images)
+        validation_loss = self.criterion_valid(outputs, labels.double())
+
+        return {"validation_loss": validation_loss}
+
+    def predict_step(self, batch, batch_idx, grad_cams=False):
+        names = batch[0]
+        images = batch[1]
+
+        outputs = self(images)
+        # necessary check for HED architecture that uses several outputs
+        # for loss calculation instead of just the last concatfuse block
+        # NOTE: must unwrap BEFORE sigmoid, else sigmoid receives a list
+        if isinstance(outputs, list):
+            outputs = outputs[-1]
+
+        probabilities = torch.sigmoid(outputs)
 
-def build_logistic_regression(input_size):
-    """Build logistic regression module.
+        return names[0], torch.flatten(probabilities), torch.flatten(batch[2])
 
-    Returns
-    -------
+    def configure_optimizers(self):
+        optimizer = getattr(torch.optim, self.hparams.optimizer)(
+            self.parameters(), **self.hparams.optimizer_configs
+        )
 
-    module : :py:class:`torch.nn.Module`
-    """
-    model = LogisticRegression(input_size)
-    model.name = "logistic_regression"
-    return model
+        return optimizer