diff --git a/src/ptbench/models/alexnet.py b/src/ptbench/models/alexnet.py
index 0b19b3d79fed91e21972daf337947b88e172f908..df224bc5f2eb973003b87a8170b0b53fd70b5aad 100644
--- a/src/ptbench/models/alexnet.py
+++ b/src/ptbench/models/alexnet.py
@@ -58,6 +58,9 @@ class Alexnet(pl.LightningModule):
     pretrained
         If set to True, loads pretrained model weights during initialization,
         else trains a new model.
+
+    num_classes
+        Number of outputs (classes) for this model.
     """
 
     def __init__(
@@ -68,6 +71,7 @@ class Alexnet(pl.LightningModule):
         optimizer_arguments: dict[str, typing.Any] = {},
         augmentation_transforms: TransformSequence = [],
         pretrained: bool = False,
+        num_classes: int = 1,
     ):
         super().__init__()
 
@@ -104,7 +108,7 @@ class Alexnet(pl.LightningModule):
 
         # Adapt output features
         self.model_ft.classifier[4] = torch.nn.Linear(4096, 512)
-        self.model_ft.classifier[6] = torch.nn.Linear(512, 1)
+        self.model_ft.classifier[6] = torch.nn.Linear(512, num_classes)
 
     def forward(self, x):
         x = self.normalizer(x)  # type: ignore
diff --git a/src/ptbench/models/densenet.py b/src/ptbench/models/densenet.py
index f6eb2cb670ba400c00b29b3122f4d4f128f37200..436d73ba55d5829cd886b37ebee882b80c41d2d8 100644
--- a/src/ptbench/models/densenet.py
+++ b/src/ptbench/models/densenet.py
@@ -56,6 +56,9 @@ class Densenet(pl.LightningModule):
     pretrained
         If set to True, loads pretrained model weights during initialization,
         else trains a new model.
+
+    num_classes
+        Number of outputs (classes) for this model.
     """
 
     def __init__(
@@ -66,6 +69,7 @@ class Densenet(pl.LightningModule):
         optimizer_arguments: dict[str, typing.Any] = {},
         augmentation_transforms: TransformSequence = [],
         pretrained: bool = False,
+        num_classes: int = 1,
     ):
         super().__init__()
 
@@ -101,9 +105,7 @@ class Densenet(pl.LightningModule):
         self.model_ft = models.densenet121(weights=weights)
 
         # Adapt output features
-        self.model_ft.classifier = torch.nn.Sequential(
-            torch.nn.Linear(1024, 256), torch.nn.Linear(256, 1)
-        )
+        self.model_ft.classifier = torch.nn.Linear(1024, num_classes)
 
     def forward(self, x):
         x = self.normalizer(x)  # type: ignore