diff --git a/src/ptbench/models/alexnet.py b/src/ptbench/models/alexnet.py
index ab66d1aa3ab183e29017e32d8957d0320381acda..5e5b3e7eee3d810aaa4c8dfed31ad3d9fa5414b5 100644
--- a/src/ptbench/models/alexnet.py
+++ b/src/ptbench/models/alexnet.py
@@ -77,6 +77,7 @@ class Alexnet(pl.LightningModule):
         super().__init__()
 
         self.name = "alexnet"
+        self.num_classes = num_classes
 
         self.model_transforms = [
             torchvision.transforms.Resize(512, antialias=True),
@@ -107,7 +108,7 @@ class Alexnet(pl.LightningModule):
 
         # Adapt output features
         self.model_ft.classifier[4] = torch.nn.Linear(4096, 512)
-        self.model_ft.classifier[6] = torch.nn.Linear(512, num_classes)
+        self.model_ft.classifier[6] = torch.nn.Linear(512, self.num_classes)
 
     def forward(self, x):
         x = self.normalizer(x)  # type: ignore
diff --git a/src/ptbench/models/densenet.py b/src/ptbench/models/densenet.py
index c7def1b5b86a4a4e9e9ab474c7aa93140e0707b7..333edb11e2112db9ef1f1bf37b2d333f5b418de6 100644
--- a/src/ptbench/models/densenet.py
+++ b/src/ptbench/models/densenet.py
@@ -75,6 +75,7 @@ class Densenet(pl.LightningModule):
         super().__init__()
 
         self.name = "densenet-121"
+        self.num_classes = num_classes
 
         # image is probably large, resize first to get memory usage down
         self.model_transforms = [
@@ -106,7 +107,7 @@ class Densenet(pl.LightningModule):
 
         # Adapt output features
         self.model_ft.classifier = torch.nn.Linear(
-            self.model_ft.classifier.in_features, num_classes
+            self.model_ft.classifier.in_features, self.num_classes
         )
 
     def forward(self, x):
diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py
index 112d7ef6cebacbb514725fe2946c22cd0c452c92..a7b8ae62c0c56bc2c5172058fdd3382c987514a1 100644
--- a/src/ptbench/models/pasa.py
+++ b/src/ptbench/models/pasa.py
@@ -59,6 +59,9 @@ class Pasa(pl.LightningModule):
     augmentation_transforms
         An optional sequence of torch modules containing transforms to be
         applied on the input **before** it is fed into the network.
+
+    num_classes
+        Number of outputs (classes) for this model.
     """
 
     def __init__(
@@ -68,10 +71,12 @@ class Pasa(pl.LightningModule):
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
         augmentation_transforms: TransformSequence = [],
+        num_classes: int = 1,
     ):
         super().__init__()
 
         self.name = "pasa"
+        self.num_classes = num_classes
 
         # image is probably large, resize first to get memory usage down
         self.model_transforms = [
@@ -146,7 +151,9 @@ class Pasa(pl.LightningModule):
         self.pool2d = torch.nn.MaxPool2d(
             (3, 3), (2, 2)
         )  # Pool after conv. block
-        self.dense = torch.nn.Linear(80, 1)  # Fully connected layer
+        self.dense = torch.nn.Linear(
+            80, self.num_classes
+        )  # Fully connected layer
 
     def forward(self, x):
         x = self.normalizer(x)  # type: ignore
diff --git a/src/ptbench/models/typing.py b/src/ptbench/models/typing.py
index 3eb9017c1e7bbbf860cf7be67622cdfe5db513df..3c64a873dc6c539b3292d9e35495f1052c2cebd1 100644
--- a/src/ptbench/models/typing.py
+++ b/src/ptbench/models/typing.py
@@ -25,3 +25,20 @@ MultiClassPredictionSplit: typing.TypeAlias = typing.Mapping[
     str, typing.Sequence[MultiClassPrediction]
 ]
 """A series of predictions for different database splits."""
+
+VisualisationType: typing.TypeAlias = typing.Literal[
+    "ablationcam",
+    "eigencam",
+    "eigengradcam",
+    "fullgrad",
+    "gradcam",
+    "gradcamelementwise",
+    "gradcam++",
+    "gradcamplusplus",
+    "hirescam",
+    "layercam",
+    "randomcam",
+    "scorecam",
+    "xgradcam",
+]
+"""Supported visualisation types."""