diff --git a/src/ptbench/models/densenet.py b/src/ptbench/models/densenet.py
index 3e2ee50e664c9a00bb34d8ea5bee298c3fba893a..0663b60b9df2f089aa93ccbff18c992bc18c6749 100644
--- a/src/ptbench/models/densenet.py
+++ b/src/ptbench/models/densenet.py
@@ -105,7 +105,9 @@ class Densenet(pl.LightningModule):
         self.model_ft = models.densenet121(weights=weights)
 
         # Adapt output features
-        self.model_ft.classifier = torch.nn.Linear(1024, num_classes)
+        self.model_ft.classifier = torch.nn.Linear(
+            self.model_ft.classifier.in_features, num_classes
+        )
 
     def forward(self, x):
         x = self.normalizer(x)  # type: ignore