Skip to content
Snippets Groups Projects
Commit 56ce93bd authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[models] Make alexnet and densenet potentially multi-class classifiers

parent bd852ec3
No related branches found
No related tags found
1 merge request!6Making use of LightningDataModule and simplification of data loading
...@@ -58,6 +58,9 @@ class Alexnet(pl.LightningModule): ...@@ -58,6 +58,9 @@ class Alexnet(pl.LightningModule):
pretrained pretrained
If set to True, loads pretrained model weights during initialization, If set to True, loads pretrained model weights during initialization,
else trains a new model. else trains a new model.
num_classes
Number of outputs (classes) for this model.
""" """
def __init__( def __init__(
...@@ -68,6 +71,7 @@ class Alexnet(pl.LightningModule): ...@@ -68,6 +71,7 @@ class Alexnet(pl.LightningModule):
optimizer_arguments: dict[str, typing.Any] = {}, optimizer_arguments: dict[str, typing.Any] = {},
augmentation_transforms: TransformSequence = [], augmentation_transforms: TransformSequence = [],
pretrained: bool = False, pretrained: bool = False,
num_classes: int = 1,
): ):
super().__init__() super().__init__()
...@@ -104,7 +108,7 @@ class Alexnet(pl.LightningModule): ...@@ -104,7 +108,7 @@ class Alexnet(pl.LightningModule):
# Adapt output features # Adapt output features
self.model_ft.classifier[4] = torch.nn.Linear(4096, 512) self.model_ft.classifier[4] = torch.nn.Linear(4096, 512)
self.model_ft.classifier[6] = torch.nn.Linear(512, 1) self.model_ft.classifier[6] = torch.nn.Linear(512, num_classes)
def forward(self, x): def forward(self, x):
x = self.normalizer(x) # type: ignore x = self.normalizer(x) # type: ignore
......
...@@ -56,6 +56,9 @@ class Densenet(pl.LightningModule): ...@@ -56,6 +56,9 @@ class Densenet(pl.LightningModule):
pretrained pretrained
If set to True, loads pretrained model weights during initialization, If set to True, loads pretrained model weights during initialization,
else trains a new model. else trains a new model.
num_classes
Number of outputs (classes) for this model.
""" """
def __init__( def __init__(
...@@ -66,6 +69,7 @@ class Densenet(pl.LightningModule): ...@@ -66,6 +69,7 @@ class Densenet(pl.LightningModule):
optimizer_arguments: dict[str, typing.Any] = {}, optimizer_arguments: dict[str, typing.Any] = {},
augmentation_transforms: TransformSequence = [], augmentation_transforms: TransformSequence = [],
pretrained: bool = False, pretrained: bool = False,
num_classes: int = 1,
): ):
super().__init__() super().__init__()
...@@ -101,9 +105,7 @@ class Densenet(pl.LightningModule): ...@@ -101,9 +105,7 @@ class Densenet(pl.LightningModule):
self.model_ft = models.densenet121(weights=weights) self.model_ft = models.densenet121(weights=weights)
# Adapt output features # Adapt output features
self.model_ft.classifier = torch.nn.Sequential( self.model_ft.classifier = torch.nn.Linear(1024, num_classes)
torch.nn.Linear(1024, 256), torch.nn.Linear(256, 1)
)
def forward(self, x): def forward(self, x):
x = self.normalizer(x) # type: ignore x = self.normalizer(x) # type: ignore
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment