From 6bd5719e0a516f3aee2989a62cf9db88af179bcd Mon Sep 17 00:00:00 2001 From: Andre Anjos <andre.dos.anjos@gmail.com> Date: Wed, 4 Oct 2023 19:40:59 +0200 Subject: [PATCH] [models] Report number of classes supported --- src/ptbench/models/alexnet.py | 3 ++- src/ptbench/models/densenet.py | 3 ++- src/ptbench/models/pasa.py | 9 ++++++++- src/ptbench/models/typing.py | 17 +++++++++++++++++ 4 files changed, 29 insertions(+), 3 deletions(-) diff --git a/src/ptbench/models/alexnet.py b/src/ptbench/models/alexnet.py index ab66d1aa..5e5b3e7e 100644 --- a/src/ptbench/models/alexnet.py +++ b/src/ptbench/models/alexnet.py @@ -77,6 +77,7 @@ class Alexnet(pl.LightningModule): super().__init__() self.name = "alexnet" + self.num_classes = num_classes self.model_transforms = [ torchvision.transforms.Resize(512, antialias=True), @@ -107,7 +108,7 @@ class Alexnet(pl.LightningModule): # Adapt output features self.model_ft.classifier[4] = torch.nn.Linear(4096, 512) - self.model_ft.classifier[6] = torch.nn.Linear(512, num_classes) + self.model_ft.classifier[6] = torch.nn.Linear(512, self.num_classes) def forward(self, x): x = self.normalizer(x) # type: ignore diff --git a/src/ptbench/models/densenet.py b/src/ptbench/models/densenet.py index c7def1b5..333edb11 100644 --- a/src/ptbench/models/densenet.py +++ b/src/ptbench/models/densenet.py @@ -75,6 +75,7 @@ class Densenet(pl.LightningModule): super().__init__() self.name = "densenet-121" + self.num_classes = num_classes # image is probably large, resize first to get memory usage down self.model_transforms = [ @@ -106,7 +107,7 @@ class Densenet(pl.LightningModule): # Adapt output features self.model_ft.classifier = torch.nn.Linear( - self.model_ft.classifier.in_features, num_classes + self.model_ft.classifier.in_features, self.num_classes ) def forward(self, x): diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py index 112d7ef6..a7b8ae62 100644 --- a/src/ptbench/models/pasa.py +++ b/src/ptbench/models/pasa.py @@ -59,6 +59,9 @@ class Pasa(pl.LightningModule): augmentation_transforms An optional sequence of torch modules containing transforms to be applied on the input **before** it is fed into the network. + + num_classes + Number of outputs (classes) for this model. """ def __init__( @@ -68,10 +71,12 @@ class Pasa(pl.LightningModule): optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam, optimizer_arguments: dict[str, typing.Any] = {}, augmentation_transforms: TransformSequence = [], + num_classes: int = 1, ): super().__init__() self.name = "pasa" + self.num_classes = num_classes # image is probably large, resize first to get memory usage down self.model_transforms = [ @@ -146,7 +151,9 @@ class Pasa(pl.LightningModule): self.pool2d = torch.nn.MaxPool2d( (3, 3), (2, 2) ) # Pool after conv. block - self.dense = torch.nn.Linear(80, 1) # Fully connected layer + self.dense = torch.nn.Linear( + 80, self.num_classes + ) # Fully connected layer def forward(self, x): x = self.normalizer(x) # type: ignore diff --git a/src/ptbench/models/typing.py b/src/ptbench/models/typing.py index 3eb9017c..3c64a873 100644 --- a/src/ptbench/models/typing.py +++ b/src/ptbench/models/typing.py @@ -25,3 +25,20 @@ MultiClassPredictionSplit: typing.TypeAlias = typing.Mapping[ str, typing.Sequence[MultiClassPrediction] ] """A series of predictions for different database splits.""" + +VisualisationType: typing.TypeAlias = typing.Literal[ + "ablationcam", + "eigencam", + "eigengradcam", + "fullgrad", + "gradcam", + "gradcamelementwise", + "gradcam++", + "gradcamplusplus", + "hirescam", + "layercam", + "randomcam", + "scorecam", + "xgradcam", +] +"""Supported visualisation types.""" -- GitLab