From 262e52a749516253756536a6af895429a623048e Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Wed, 2 Aug 2023 12:22:01 +0200
Subject: [PATCH] [models.densenet] Use internal model constant

---
 src/ptbench/models/densenet.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/ptbench/models/densenet.py b/src/ptbench/models/densenet.py
index 3e2ee50e..0663b60b 100644
--- a/src/ptbench/models/densenet.py
+++ b/src/ptbench/models/densenet.py
@@ -105,7 +105,9 @@ class Densenet(pl.LightningModule):
         self.model_ft = models.densenet121(weights=weights)
 
         # Adapt output features
-        self.model_ft.classifier = torch.nn.Linear(1024, num_classes)
+        self.model_ft.classifier = torch.nn.Linear(
+            self.model_ft.classifier.in_features, num_classes
+        )
 
     def forward(self, x):
         x = self.normalizer(x)  # type: ignore
-- 
GitLab