From 6f12ad95f3c3cf6075a9ae5244237304bb960d49 Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Thu, 20 Jul 2023 15:52:19 +0200
Subject: [PATCH] [models] Make alexnet and densenet potentially multi-class
 classifiers

---
 src/ptbench/models/alexnet.py  | 6 +++++-
 src/ptbench/models/densenet.py | 8 +++++---
 2 files changed, 10 insertions(+), 4 deletions(-)

diff --git a/src/ptbench/models/alexnet.py b/src/ptbench/models/alexnet.py
index 0b19b3d7..df224bc5 100644
--- a/src/ptbench/models/alexnet.py
+++ b/src/ptbench/models/alexnet.py
@@ -58,6 +58,9 @@ class Alexnet(pl.LightningModule):
     pretrained
         If set to True, loads pretrained model weights during initialization,
         else trains a new model.
+
+    num_classes
+        Number of outputs (classes) for this model.
     """
 
     def __init__(
@@ -68,6 +71,7 @@ class Alexnet(pl.LightningModule):
         optimizer_arguments: dict[str, typing.Any] = {},
         augmentation_transforms: TransformSequence = [],
         pretrained: bool = False,
+        num_classes: int = 1,
     ):
         super().__init__()
 
@@ -104,7 +108,7 @@ class Alexnet(pl.LightningModule):
 
         # Adapt output features
         self.model_ft.classifier[4] = torch.nn.Linear(4096, 512)
-        self.model_ft.classifier[6] = torch.nn.Linear(512, 1)
+        self.model_ft.classifier[6] = torch.nn.Linear(512, num_classes)
 
     def forward(self, x):
         x = self.normalizer(x)  # type: ignore
diff --git a/src/ptbench/models/densenet.py b/src/ptbench/models/densenet.py
index f6eb2cb6..436d73ba 100644
--- a/src/ptbench/models/densenet.py
+++ b/src/ptbench/models/densenet.py
@@ -56,6 +56,9 @@ class Densenet(pl.LightningModule):
     pretrained
         If set to True, loads pretrained model weights during initialization,
         else trains a new model.
+
+    num_classes
+        Number of outputs (classes) for this model.
     """
 
     def __init__(
@@ -66,6 +69,7 @@ class Densenet(pl.LightningModule):
         optimizer_arguments: dict[str, typing.Any] = {},
         augmentation_transforms: TransformSequence = [],
         pretrained: bool = False,
+        num_classes: int = 1,
     ):
         super().__init__()
 
@@ -101,9 +105,7 @@ class Densenet(pl.LightningModule):
         self.model_ft = models.densenet121(weights=weights)
 
         # Adapt output features
-        self.model_ft.classifier = torch.nn.Sequential(
-            torch.nn.Linear(1024, 256), torch.nn.Linear(256, 1)
-        )
+        self.model_ft.classifier = torch.nn.Linear(1024, num_classes)
 
     def forward(self, x):
         x = self.normalizer(x)  # type: ignore
-- 
GitLab