diff --git a/src/ptbench/configs/models/alexnet.py b/src/ptbench/configs/models/alexnet.py index 5917db2146ac2a1efa4584cbc6126a3e88ae7f86..ecaee487d1878b1e73bdee8261f06ee475834106 100644 --- a/src/ptbench/configs/models/alexnet.py +++ b/src/ptbench/configs/models/alexnet.py @@ -4,19 +4,19 @@ """AlexNet.""" +from torch import empty from torch.nn import BCEWithLogitsLoss -from torch.optim import SGD -from ...models.alexnet import build_alexnet +from ...models.alexnet import Alexnet # config -lr = 0.01 - -# model -model = build_alexnet(pretrained=False) +optimizer_configs = {"lr": 0.01, "momentum": 0.1} # optimizer -optimizer = SGD(model.parameters(), lr=lr, momentum=0.1) - +optimizer = "SGD" # criterion -criterion = BCEWithLogitsLoss() +criterion = BCEWithLogitsLoss(pos_weight=empty(1)) +criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1)) + +# model +model = Alexnet(criterion, criterion_valid, optimizer, optimizer_configs) diff --git a/src/ptbench/models/alexnet.py b/src/ptbench/models/alexnet.py index ea096ecbdb51f05679d37594fa41d7c4788d8874..7aaaccb6a18e12e5e0d03c00e9f0974d470b31df 100644 --- a/src/ptbench/models/alexnet.py +++ b/src/ptbench/models/alexnet.py @@ -2,29 +2,45 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from collections import OrderedDict - +import pytorch_lightning as pl +import torch import torch.nn as nn import torchvision.models as models from .normalizer import TorchVisionNormalizer -class Alexnet(nn.Module): +class Alexnet(pl.LightningModule): """Alexnet module. Note: only usable with a normalized dataset """ - def __init__(self, pretrained=False): + def __init__( + self, + criterion, + criterion_valid, + optimizer, + optimizer_configs, + pretrained=False, + ): super().__init__() + self.save_hyperparameters() + + self.criterion = criterion + self.criterion_valid = criterion_valid + + self.name = "AlexNet" + # Load pretrained model weights = ( None if pretrained is False else models.AlexNet_Weights.DEFAULT ) self.model_ft = models.alexnet(weights=weights) + self.normalizer = TorchVisionNormalizer(nb_channels=1) + # Adapt output features self.model_ft.classifier[4] = nn.Linear(4096, 512) self.model_ft.classifier[6] = nn.Linear(512, 1) @@ -44,20 +60,59 @@ class Alexnet(nn.Module): tensor : :py:class:`torch.Tensor` """ - return self.model_ft(x) + x = self.normalizer(x) + x = self.model_ft(x) + return x -def build_alexnet(pretrained=False): - """Build Alexnet CNN. + def training_step(self, batch, batch_idx): + images = batch[1] + labels = batch[2] - Returns - ------- + # Increase label dimension if too low + # Allows single and multiclass usage + if labels.ndim == 1: + labels = torch.reshape(labels, (labels.shape[0], 1)) - module : :py:class:`torch.nn.Module` - """ - model = Alexnet(pretrained=pretrained) - model = [("normalizer", TorchVisionNormalizer()), ("model", model)] - model = nn.Sequential(OrderedDict(model)) + # Forward pass on the network + outputs = self(images) + + training_loss = self.criterion(outputs, labels.double()) + + return {"loss": training_loss} + + def validation_step(self, batch, batch_idx): + images = batch[1] + labels = batch[2] + + # Increase label dimension if too low + # Allows single and multiclass usage + if labels.ndim == 1: + labels = torch.reshape(labels, (labels.shape[0], 1)) + + # data forwarding on the existing network + outputs = self(images) + validation_loss = self.criterion_valid(outputs, labels.double()) + + return {"validation_loss": validation_loss} + + def predict_step(self, batch, batch_idx, grad_cams=False): + names = batch[0] + images = batch[1] + + outputs = self(images) + probabilities = torch.sigmoid(outputs) + + # necessary check for HED architecture that uses several outputs + # for loss calculation instead of just the last concatfuse block + if isinstance(outputs, list): + outputs = outputs[-1] + + return names[0], torch.flatten(probabilities), torch.flatten(batch[2]) + + def configure_optimizers(self): + optimizer = getattr(torch.optim, self.hparams.optimizer)( + self.parameters(), **self.hparams.optimizer_configs + ) - model.name = "AlexNet" - return model + return optimizer