From 6bd5719e0a516f3aee2989a62cf9db88af179bcd Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Wed, 4 Oct 2023 19:40:59 +0200
Subject: [PATCH] [models] Report number of classes supported

---
 src/ptbench/models/alexnet.py  |  3 ++-
 src/ptbench/models/densenet.py |  3 ++-
 src/ptbench/models/pasa.py     |  9 ++++++++-
 src/ptbench/models/typing.py   | 17 +++++++++++++++++
 4 files changed, 29 insertions(+), 3 deletions(-)

diff --git a/src/ptbench/models/alexnet.py b/src/ptbench/models/alexnet.py
index ab66d1aa..5e5b3e7e 100644
--- a/src/ptbench/models/alexnet.py
+++ b/src/ptbench/models/alexnet.py
@@ -77,6 +77,7 @@ class Alexnet(pl.LightningModule):
         super().__init__()
 
         self.name = "alexnet"
+        self.num_classes = num_classes
 
         self.model_transforms = [
             torchvision.transforms.Resize(512, antialias=True),
@@ -107,7 +108,7 @@ class Alexnet(pl.LightningModule):
 
         # Adapt output features
         self.model_ft.classifier[4] = torch.nn.Linear(4096, 512)
-        self.model_ft.classifier[6] = torch.nn.Linear(512, num_classes)
+        self.model_ft.classifier[6] = torch.nn.Linear(512, self.num_classes)
 
     def forward(self, x):
         x = self.normalizer(x)  # type: ignore
diff --git a/src/ptbench/models/densenet.py b/src/ptbench/models/densenet.py
index c7def1b5..333edb11 100644
--- a/src/ptbench/models/densenet.py
+++ b/src/ptbench/models/densenet.py
@@ -75,6 +75,7 @@ class Densenet(pl.LightningModule):
         super().__init__()
 
         self.name = "densenet-121"
+        self.num_classes = num_classes
 
         # image is probably large, resize first to get memory usage down
         self.model_transforms = [
@@ -106,7 +107,7 @@ class Densenet(pl.LightningModule):
 
         # Adapt output features
         self.model_ft.classifier = torch.nn.Linear(
-            self.model_ft.classifier.in_features, num_classes
+            self.model_ft.classifier.in_features, self.num_classes
         )
 
     def forward(self, x):
diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py
index 112d7ef6..a7b8ae62 100644
--- a/src/ptbench/models/pasa.py
+++ b/src/ptbench/models/pasa.py
@@ -59,6 +59,9 @@ class Pasa(pl.LightningModule):
     augmentation_transforms
         An optional sequence of torch modules containing transforms to be
         applied on the input **before** it is fed into the network.
+
+    num_classes
+        Number of outputs (classes) for this model.
     """
 
     def __init__(
@@ -68,10 +71,12 @@ class Pasa(pl.LightningModule):
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
         augmentation_transforms: TransformSequence = [],
+        num_classes: int = 1,
     ):
         super().__init__()
 
         self.name = "pasa"
+        self.num_classes = num_classes
 
         # image is probably large, resize first to get memory usage down
         self.model_transforms = [
@@ -146,7 +151,9 @@ class Pasa(pl.LightningModule):
         self.pool2d = torch.nn.MaxPool2d(
             (3, 3), (2, 2)
         )  # Pool after conv. block
-        self.dense = torch.nn.Linear(80, 1)  # Fully connected layer
+        self.dense = torch.nn.Linear(
+            80, self.num_classes
+        )  # Fully connected layer
 
     def forward(self, x):
         x = self.normalizer(x)  # type: ignore
diff --git a/src/ptbench/models/typing.py b/src/ptbench/models/typing.py
index 3eb9017c..3c64a873 100644
--- a/src/ptbench/models/typing.py
+++ b/src/ptbench/models/typing.py
@@ -25,3 +25,20 @@ MultiClassPredictionSplit: typing.TypeAlias = typing.Mapping[
     str, typing.Sequence[MultiClassPrediction]
 ]
 """A series of predictions for different database splits."""
+
+VisualisationType: typing.TypeAlias = typing.Literal[
+    "ablationcam",
+    "eigencam",
+    "eigengradcam",
+    "fullgrad",
+    "gradcam",
+    "gradcamelementwise",
+    "gradcam++",
+    "gradcamplusplus",
+    "hirescam",
+    "layercam",
+    "randomcam",
+    "scorecam",
+    "xgradcam",
+]
+"""Supported visualisation types."""
-- 
GitLab