Skip to content
Snippets Groups Projects
Commit 56ce93bd authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[models] Make alexnet and densenet potentially multi-class classifiers

parent bd852ec3
No related branches found
No related tags found
1 merge request!6Making use of LightningDataModule and simplification of data loading
......@@ -58,6 +58,9 @@ class Alexnet(pl.LightningModule):
pretrained
If set to True, loads pretrained model weights during initialization,
else trains a new model.
num_classes
Number of outputs (classes) for this model.
"""
def __init__(
......@@ -68,6 +71,7 @@ class Alexnet(pl.LightningModule):
optimizer_arguments: dict[str, typing.Any] = {},
augmentation_transforms: TransformSequence = [],
pretrained: bool = False,
num_classes: int = 1,
):
super().__init__()
......@@ -104,7 +108,7 @@ class Alexnet(pl.LightningModule):
# Adapt output features
self.model_ft.classifier[4] = torch.nn.Linear(4096, 512)
self.model_ft.classifier[6] = torch.nn.Linear(512, 1)
self.model_ft.classifier[6] = torch.nn.Linear(512, num_classes)
def forward(self, x):
x = self.normalizer(x) # type: ignore
......
......@@ -56,6 +56,9 @@ class Densenet(pl.LightningModule):
pretrained
If set to True, loads pretrained model weights during initialization,
else trains a new model.
num_classes
Number of outputs (classes) for this model.
"""
def __init__(
......@@ -66,6 +69,7 @@ class Densenet(pl.LightningModule):
optimizer_arguments: dict[str, typing.Any] = {},
augmentation_transforms: TransformSequence = [],
pretrained: bool = False,
num_classes: int = 1,
):
super().__init__()
......@@ -101,9 +105,7 @@ class Densenet(pl.LightningModule):
self.model_ft = models.densenet121(weights=weights)
# Adapt output features
self.model_ft.classifier = torch.nn.Sequential(
torch.nn.Linear(1024, 256), torch.nn.Linear(256, 1)
)
self.model_ft.classifier = torch.nn.Linear(1024, num_classes)
def forward(self, x):
x = self.normalizer(x) # type: ignore
......
0% Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment