Skip to content
Snippets Groups Projects
Commit 6bd5719e authored by André Anjos's avatar André Anjos :speech_balloon: Committed by Daniel CARRON
Browse files

[models] Report number of classes supported

parent 32cae111
No related branches found
No related tags found
1 merge request!12Adds grad-cam support on classifiers
......@@ -77,6 +77,7 @@ class Alexnet(pl.LightningModule):
super().__init__()
self.name = "alexnet"
self.num_classes = num_classes
self.model_transforms = [
torchvision.transforms.Resize(512, antialias=True),
......@@ -107,7 +108,7 @@ class Alexnet(pl.LightningModule):
# Adapt output features
self.model_ft.classifier[4] = torch.nn.Linear(4096, 512)
self.model_ft.classifier[6] = torch.nn.Linear(512, num_classes)
self.model_ft.classifier[6] = torch.nn.Linear(512, self.num_classes)
def forward(self, x):
x = self.normalizer(x) # type: ignore
......
......@@ -75,6 +75,7 @@ class Densenet(pl.LightningModule):
super().__init__()
self.name = "densenet-121"
self.num_classes = num_classes
# image is probably large, resize first to get memory usage down
self.model_transforms = [
......@@ -106,7 +107,7 @@ class Densenet(pl.LightningModule):
# Adapt output features
self.model_ft.classifier = torch.nn.Linear(
self.model_ft.classifier.in_features, num_classes
self.model_ft.classifier.in_features, self.num_classes
)
def forward(self, x):
......
......@@ -59,6 +59,9 @@ class Pasa(pl.LightningModule):
augmentation_transforms
An optional sequence of torch modules containing transforms to be
applied on the input **before** it is fed into the network.
num_classes
Number of outputs (classes) for this model.
"""
def __init__(
......@@ -68,10 +71,12 @@ class Pasa(pl.LightningModule):
optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
optimizer_arguments: dict[str, typing.Any] = {},
augmentation_transforms: TransformSequence = [],
num_classes: int = 1,
):
super().__init__()
self.name = "pasa"
self.num_classes = num_classes
# image is probably large, resize first to get memory usage down
self.model_transforms = [
......@@ -146,7 +151,9 @@ class Pasa(pl.LightningModule):
self.pool2d = torch.nn.MaxPool2d(
(3, 3), (2, 2)
) # Pool after conv. block
self.dense = torch.nn.Linear(80, 1) # Fully connected layer
self.dense = torch.nn.Linear(
80, self.num_classes
) # Fully connected layer
def forward(self, x):
x = self.normalizer(x) # type: ignore
......
......@@ -25,3 +25,20 @@ MultiClassPredictionSplit: typing.TypeAlias = typing.Mapping[
str, typing.Sequence[MultiClassPrediction]
]
"""A series of predictions for different database splits."""
VisualisationType: typing.TypeAlias = typing.Literal[
"ablationcam",
"eigencam",
"eigengradcam",
"fullgrad",
"gradcam",
"gradcamelementwise",
"gradcam++",
"gradcamplusplus",
"hirescam",
"layercam",
"randomcam",
"scorecam",
"xgradcam",
]
"""Supported visualisation types."""
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment