diff --git a/src/ptbench/configs/models/alexnet.py b/src/ptbench/configs/models/alexnet.py
index 5917db2146ac2a1efa4584cbc6126a3e88ae7f86..ecaee487d1878b1e73bdee8261f06ee475834106 100644
--- a/src/ptbench/configs/models/alexnet.py
+++ b/src/ptbench/configs/models/alexnet.py
@@ -4,19 +4,21 @@
 
 """AlexNet."""
 
+from torch import empty
 from torch.nn import BCEWithLogitsLoss
-from torch.optim import SGD
 
-from ...models.alexnet import build_alexnet
+from ...models.alexnet import Alexnet
 
 # config
-lr = 0.01
-
-# model
-model = build_alexnet(pretrained=False)
+optimizer_configs = {"lr": 0.01, "momentum": 0.1}
 
 # optimizer
-optimizer = SGD(model.parameters(), lr=lr, momentum=0.1)
-
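+# the optimizer is passed by name; the model instantiates it in configure_optimizers()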
+optimizer = "SGD"
 # criterion
-criterion = BCEWithLogitsLoss()
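+# pos_weight starts as an uninitialized placeholder, expected to be set before training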
+criterion = BCEWithLogitsLoss(pos_weight=empty(1))
+criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1))
+
+# model
+model = Alexnet(criterion, criterion_valid, optimizer, optimizer_configs)
diff --git a/src/ptbench/models/alexnet.py b/src/ptbench/models/alexnet.py
index ea096ecbdb51f05679d37594fa41d7c4788d8874..7aaaccb6a18e12e5e0d03c00e9f0974d470b31df 100644
--- a/src/ptbench/models/alexnet.py
+++ b/src/ptbench/models/alexnet.py
@@ -2,29 +2,46 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
-from collections import OrderedDict
-
+import pytorch_lightning as pl
+import torch
 import torch.nn as nn
 import torchvision.models as models
 
 from .normalizer import TorchVisionNormalizer
 
 
-class Alexnet(nn.Module):
+class Alexnet(pl.LightningModule):
     """Alexnet module.
 
     Note: only usable with a normalized dataset
     """
 
-    def __init__(self, pretrained=False):
+    def __init__(
+        self,
+        criterion,
+        criterion_valid,
+        optimizer,
+        optimizer_configs,
+        pretrained=False,
+    ):
         super().__init__()
 
+        self.save_hyperparameters()
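+        # constructor arguments are now available under self.hparams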
+
+        self.criterion = criterion
+        self.criterion_valid = criterion_valid
+
+        self.name = "AlexNet"
+
         # Load pretrained model
         weights = (
             None if pretrained is False else models.AlexNet_Weights.DEFAULT
         )
         self.model_ft = models.alexnet(weights=weights)
 
+        self.normalizer = TorchVisionNormalizer(nb_channels=1)
+
         # Adapt output features
         self.model_ft.classifier[4] = nn.Linear(4096, 512)
         self.model_ft.classifier[6] = nn.Linear(512, 1)
@@ -44,20 +61,63 @@ class Alexnet(nn.Module):
         tensor : :py:class:`torch.Tensor`
 
         """
-        return self.model_ft(x)
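+        # normalize inputs, then run the AlexNet backbone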
+        x = self.normalizer(x)
+        x = self.model_ft(x)
 
+        return x
 
-def build_alexnet(pretrained=False):
-    """Build Alexnet CNN.
+    def training_step(self, batch, batch_idx):
+        images = batch[1]
+        labels = batch[2]
 
-    Returns
-    -------
+        # Reshape 1D labels to (batch_size, 1), allowing the same
+        # code path for single- and multi-class usage
+        if labels.ndim == 1:
+            labels = torch.reshape(labels, (labels.shape[0], 1))
 
-    module : :py:class:`torch.nn.Module`
-    """
-    model = Alexnet(pretrained=pretrained)
-    model = [("normalizer", TorchVisionNormalizer()), ("model", model)]
-    model = nn.Sequential(OrderedDict(model))
+        # Forward pass on the network
+        outputs = self(images)
+
+        training_loss = self.criterion(outputs, labels.double())
+
+        return {"loss": training_loss}
+
+    def validation_step(self, batch, batch_idx):
+        images = batch[1]
+        labels = batch[2]
+
+        # Reshape 1D labels to (batch_size, 1), allowing the same
+        # code path for single- and multi-class usage
+        if labels.ndim == 1:
+            labels = torch.reshape(labels, (labels.shape[0], 1))
+
+        # Forward pass on the network
+        outputs = self(images)
+        validation_loss = self.criterion_valid(outputs, labels.double())
+
+        return {"validation_loss": validation_loss}
+
+    def predict_step(self, batch, batch_idx, grad_cams=False):
+        names = batch[0]
+        images = batch[1]
+
+        outputs = self(images)
+
+        # necessary check for HED architecture that uses several outputs
+        # for loss calculation instead of just the last concatfuse block;
+        # the check must run before the sigmoid, which expects a tensor
+        if isinstance(outputs, list):
+            outputs = outputs[-1]
+
+        probabilities = torch.sigmoid(outputs)
+
+        return names[0], torch.flatten(probabilities), torch.flatten(batch[2])
+
+    def configure_optimizers(self):
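+        # resolve the optimizer class by name from torch.optim (e.g. "SGD")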
+        optimizer = getattr(torch.optim, self.hparams.optimizer)(
+            self.parameters(), **self.hparams.optimizer_configs
+        )
 
-    model.name = "AlexNet"
-    return model
+        return optimizer