From 262e52a749516253756536a6af895429a623048e Mon Sep 17 00:00:00 2001 From: Andre Anjos <andre.dos.anjos@gmail.com> Date: Wed, 2 Aug 2023 12:22:01 +0200 Subject: [PATCH] [models.densenet] Use internal model constant --- src/ptbench/models/densenet.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/ptbench/models/densenet.py b/src/ptbench/models/densenet.py index 3e2ee50e..0663b60b 100644 --- a/src/ptbench/models/densenet.py +++ b/src/ptbench/models/densenet.py @@ -105,7 +105,9 @@ class Densenet(pl.LightningModule): self.model_ft = models.densenet121(weights=weights) # Adapt output features - self.model_ft.classifier = torch.nn.Linear(1024, num_classes) + self.model_ft.classifier = torch.nn.Linear( + self.model_ft.classifier.in_features, num_classes + ) def forward(self, x): x = self.normalizer(x) # type: ignore -- GitLab