diff --git a/src/ptbench/models/alexnet.py b/src/ptbench/models/alexnet.py index 0b19b3d79fed91e21972daf337947b88e172f908..df224bc5f2eb973003b87a8170b0b53fd70b5aad 100644 --- a/src/ptbench/models/alexnet.py +++ b/src/ptbench/models/alexnet.py @@ -58,6 +58,9 @@ class Alexnet(pl.LightningModule): pretrained If set to True, loads pretrained model weights during initialization, else trains a new model. + + num_classes + Number of outputs (classes) for this model. """ def __init__( @@ -68,6 +71,7 @@ class Alexnet(pl.LightningModule): optimizer_arguments: dict[str, typing.Any] = {}, augmentation_transforms: TransformSequence = [], pretrained: bool = False, + num_classes: int = 1, ): super().__init__() @@ -104,7 +108,7 @@ class Alexnet(pl.LightningModule): # Adapt output features self.model_ft.classifier[4] = torch.nn.Linear(4096, 512) - self.model_ft.classifier[6] = torch.nn.Linear(512, 1) + self.model_ft.classifier[6] = torch.nn.Linear(512, num_classes) def forward(self, x): x = self.normalizer(x) # type: ignore diff --git a/src/ptbench/models/densenet.py b/src/ptbench/models/densenet.py index f6eb2cb670ba400c00b29b3122f4d4f128f37200..436d73ba55d5829cd886b37ebee882b80c41d2d8 100644 --- a/src/ptbench/models/densenet.py +++ b/src/ptbench/models/densenet.py @@ -56,6 +56,9 @@ class Densenet(pl.LightningModule): pretrained If set to True, loads pretrained model weights during initialization, else trains a new model. + + num_classes + Number of outputs (classes) for this model. """ def __init__( @@ -66,6 +69,7 @@ class Densenet(pl.LightningModule): optimizer_arguments: dict[str, typing.Any] = {}, augmentation_transforms: TransformSequence = [], pretrained: bool = False, + num_classes: int = 1, ): super().__init__() @@ -101,9 +105,7 @@ class Densenet(pl.LightningModule): self.model_ft = models.densenet121(weights=weights) # Adapt output features - self.model_ft.classifier = torch.nn.Sequential( - torch.nn.Linear(1024, 256), torch.nn.Linear(256, 1) - ) + self.model_ft.classifier = torch.nn.Linear(1024, num_classes) def forward(self, x): x = self.normalizer(x) # type: ignore