diff --git a/src/ptbench/models/densenet.py b/src/ptbench/models/densenet.py index 3e2ee50e664c9a00bb34d8ea5bee298c3fba893a..0663b60b9df2f089aa93ccbff18c992bc18c6749 100644 --- a/src/ptbench/models/densenet.py +++ b/src/ptbench/models/densenet.py @@ -105,7 +105,9 @@ class Densenet(pl.LightningModule): self.model_ft = models.densenet121(weights=weights) # Adapt output features - self.model_ft.classifier = torch.nn.Linear(1024, num_classes) + self.model_ft.classifier = torch.nn.Linear( + self.model_ft.classifier.in_features, num_classes + ) def forward(self, x): x = self.normalizer(x) # type: ignore