From 6f12ad95f3c3cf6075a9ae5244237304bb960d49 Mon Sep 17 00:00:00 2001 From: Andre Anjos <andre.dos.anjos@gmail.com> Date: Thu, 20 Jul 2023 15:52:19 +0200 Subject: [PATCH] [models] Make alexnet and densenet potentially multi-class classifiers --- src/ptbench/models/alexnet.py | 6 +++++- src/ptbench/models/densenet.py | 8 +++++--- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/src/ptbench/models/alexnet.py b/src/ptbench/models/alexnet.py index 0b19b3d7..df224bc5 100644 --- a/src/ptbench/models/alexnet.py +++ b/src/ptbench/models/alexnet.py @@ -58,6 +58,9 @@ class Alexnet(pl.LightningModule): pretrained If set to True, loads pretrained model weights during initialization, else trains a new model. + + num_classes + Number of outputs (classes) for this model. """ def __init__( @@ -68,6 +71,7 @@ class Alexnet(pl.LightningModule): optimizer_arguments: dict[str, typing.Any] = {}, augmentation_transforms: TransformSequence = [], pretrained: bool = False, + num_classes: int = 1, ): super().__init__() @@ -104,7 +108,7 @@ class Alexnet(pl.LightningModule): # Adapt output features self.model_ft.classifier[4] = torch.nn.Linear(4096, 512) - self.model_ft.classifier[6] = torch.nn.Linear(512, 1) + self.model_ft.classifier[6] = torch.nn.Linear(512, num_classes) def forward(self, x): x = self.normalizer(x) # type: ignore diff --git a/src/ptbench/models/densenet.py b/src/ptbench/models/densenet.py index f6eb2cb6..436d73ba 100644 --- a/src/ptbench/models/densenet.py +++ b/src/ptbench/models/densenet.py @@ -56,6 +56,9 @@ class Densenet(pl.LightningModule): pretrained If set to True, loads pretrained model weights during initialization, else trains a new model. + + num_classes + Number of outputs (classes) for this model. """ def __init__( @@ -66,6 +69,7 @@ class Densenet(pl.LightningModule): optimizer_arguments: dict[str, typing.Any] = {}, augmentation_transforms: TransformSequence = [], pretrained: bool = False, + num_classes: int = 1, ): super().__init__() @@ -101,9 +105,7 @@ class Densenet(pl.LightningModule): self.model_ft = models.densenet121(weights=weights) # Adapt output features - self.model_ft.classifier = torch.nn.Sequential( - torch.nn.Linear(1024, 256), torch.nn.Linear(256, 1) - ) + self.model_ft.classifier = torch.nn.Linear(1024, num_classes) def forward(self, x): x = self.normalizer(x) # type: ignore -- GitLab