diff --git a/src/ptbench/configs/models/alexnet.py b/src/ptbench/configs/models/alexnet.py
index cf8bfd35aa10ad3493be9483ba0347fc8ebb7da1..2361b886d500fee740e456f2505a36da4fdaf4e3 100644
--- a/src/ptbench/configs/models/alexnet.py
+++ b/src/ptbench/configs/models/alexnet.py
@@ -6,19 +6,30 @@
 from torch import empty
 from torch.nn import BCEWithLogitsLoss
+from torch.optim import SGD
 
 from ...models.alexnet import Alexnet
 
-# config
+# optimizer
+optimizer = SGD
 optimizer_configs = {"lr": 0.01, "momentum": 0.1}
 
-# optimizer
-optimizer = "SGD"
 
 # criterion
 criterion = BCEWithLogitsLoss(pos_weight=empty(1))
 criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1))
 
+from ...data.transforms import ElasticDeformation
+
+augmentation_transforms = [
+    ElasticDeformation(p=0.8),
+]
+
 # model
 model = Alexnet(
-    criterion, criterion_valid, optimizer, optimizer_configs, pretrained=False
+    criterion,
+    criterion_valid,
+    optimizer,
+    optimizer_configs,
+    pretrained=False,
+    augmentation_transforms=augmentation_transforms,
 )
diff --git a/src/ptbench/configs/models/alexnet_pretrained.py b/src/ptbench/configs/models/alexnet_pretrained.py
index 1d196be6f79ea5c70987c1d1a66eaf32e8e7ca4c..0dc7e5d67d007cf5e7e358e7fa75243a47047c4b 100644
--- a/src/ptbench/configs/models/alexnet_pretrained.py
+++ b/src/ptbench/configs/models/alexnet_pretrained.py
@@ -6,19 +6,30 @@
 from torch import empty
 from torch.nn import BCEWithLogitsLoss
+from torch.optim import SGD
 
 from ...models.alexnet import Alexnet
 
-# config
-optimizer_configs = {"lr": 0.001, "momentum": 0.1}
-
 # optimizer
-optimizer = "SGD"
+optimizer = SGD
+optimizer_configs = {"lr": 0.01, "momentum": 0.1}
+
 
 # criterion
 criterion = BCEWithLogitsLoss(pos_weight=empty(1))
 criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1))
 
+from ...data.transforms import ElasticDeformation
+
+augmentation_transforms = [
+    ElasticDeformation(p=0.8),
+]
+
 # model
 model = Alexnet(
-    criterion, criterion_valid, optimizer, optimizer_configs, pretrained=True
+    criterion,
+    criterion_valid,
+    optimizer,
+    optimizer_configs,
+    pretrained=True,
+    augmentation_transforms=augmentation_transforms,
 )
diff --git a/src/ptbench/models/alexnet.py b/src/ptbench/models/alexnet.py
index ba9bf05f7428d759489bd744f8ec35c3b43bab02..55898b6759e4e471607bbe87cff0de3fb074724c 100644
--- a/src/ptbench/models/alexnet.py
+++ b/src/ptbench/models/alexnet.py
@@ -2,12 +2,15 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
+import logging
+
 import lightning.pytorch as pl
 import torch
 import torch.nn as nn
 import torchvision.models as models
+import torchvision.transforms
 
-from .normalizer import TorchVisionNormalizer
+logger = logging.getLogger(__name__)
 
 
 class Alexnet(pl.LightningModule):
@@ -18,25 +21,38 @@
 
     def __init__(
        self,
-        criterion,
-        criterion_valid,
-        optimizer,
-        optimizer_configs,
+        criterion=None,
+        criterion_valid=None,
+        optimizer=None,
+        optimizer_configs=None,
         pretrained=False,
+        augmentation_transforms=[],
     ):
         super().__init__()
 
-        self.save_hyperparameters(ignore=["criterion", "criterion_valid"])
-
         self.name = "AlexNet"
 
-        # Load pretrained model
-        weights = (
-            None if pretrained is False else models.AlexNet_Weights.DEFAULT
+        self.augmentation_transforms = torchvision.transforms.Compose(
+            augmentation_transforms
         )
 
-        self.model_ft = models.alexnet(weights=weights)
-        self.normalizer = TorchVisionNormalizer(nb_channels=1)
+        self.criterion = criterion
+        self.criterion_valid = criterion_valid
+
+        self.optimizer = optimizer
+        self.optimizer_configs = optimizer_configs
+
+        self.normalizer = None
+        self.pretrained = pretrained
+
+        # Load pretrained model
+        if not pretrained:
+            weights = None
+        else:
+            logger.info("Loading pretrained model weights")
+            weights = models.AlexNet_Weights.DEFAULT
+
+        self.model_ft = models.alexnet(weights=weights)
 
         # Adapt output features
         self.model_ft.classifier[4] = nn.Linear(4096, 512)
@@ -48,9 +64,69 @@
 
         return x
 
+    def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None:
+        """Initializes the normalizer for the current model.
+
+        If ``pretrained = True``, the dataloader is ignored and the
+        normalizer is set to the preset ImageNet factors instead.
+
+        Parameters
+        ----------
+
+        dataloader: :py:class:`torch.utils.data.DataLoader`
+            A torch Dataloader from which to compute the mean and std.
+            Will not be used if the model is pretrained.
+        """
+        if self.pretrained:
+            from .normalizer import make_imagenet_normalizer
+
+            logger.warning(
+                "ImageNet pre-trained alexnet model - NOT "
+                "computing z-norm factors from training data. "
+                "Using preset factors from torchvision."
+            )
+            self.normalizer = make_imagenet_normalizer()
+        else:
+            from .normalizer import make_z_normalizer
+
+            logger.info(
+                "Uninitialised alexnet model - "
+                "computing z-norm factors from training data."
+            )
+            self.normalizer = make_z_normalizer(dataloader)
+
+    def set_bce_loss_weights(self, datamodule):
+        """Reweights the BCEWithLogitsLoss criteria with positive-class weights.
+
+        Parameters
+        ----------
+
+        datamodule:
+            A datamodule implementing train_dataloader() and val_dataloader().
+        """
+        from ..data.dataset import _get_positive_weights
+
+        if isinstance(self.criterion, torch.nn.BCEWithLogitsLoss):
+            logger.info("Reweighting BCEWithLogitsLoss training criterion.")
+            train_positive_weights = _get_positive_weights(
+                datamodule.train_dataloader()
+            )
+            self.criterion = torch.nn.BCEWithLogitsLoss(
+                pos_weight=train_positive_weights
+            )
+
+        if isinstance(self.criterion_valid, torch.nn.BCEWithLogitsLoss):
+            logger.info("Reweighting BCEWithLogitsLoss validation criterion.")
+            validation_positive_weights = _get_positive_weights(
+                datamodule.val_dataloader()["validation"]
+            )
+            self.criterion_valid = torch.nn.BCEWithLogitsLoss(
+                pos_weight=validation_positive_weights
+            )
+
     def training_step(self, batch, batch_idx):
-        images = batch[1]
-        labels = batch[2]
+        images = batch[0]
+        labels = batch[1]["label"]
 
         # Increase label dimension if too low
         # Allows single and multiclass usage
@@ -58,17 +134,20 @@
             labels = torch.reshape(labels, (labels.shape[0], 1))
 
         # Forward pass on the network
-        outputs = self(images)
+        augmented_images = [
+            self.augmentation_transforms(img).to(self.device) for img in images
+        ]
+        # Combine list of augmented images back into a tensor
+        augmented_images = torch.cat(augmented_images, 0).view(images.shape)
+        outputs = self(augmented_images)
 
-        # Manually move criterion to selected device, since not part of the model.
-        self.hparams.criterion = self.hparams.criterion.to(self.device)
-        training_loss = self.hparams.criterion(outputs, labels.float())
+        training_loss = self.criterion(outputs, labels.float())
 
         return {"loss": training_loss}
 
     def validation_step(self, batch, batch_idx, dataloader_idx=0):
-        images = batch[1]
-        labels = batch[2]
+        images = batch[0]
+        labels = batch[1]["label"]
 
         # Increase label dimension if too low
         # Allows single and multiclass usage
@@ -78,11 +157,7 @@
         # data forwarding on the existing network
         outputs = self(images)
 
-        # Manually move criterion to selected device, since not part of the model.
-        self.hparams.criterion_valid = self.hparams.criterion_valid.to(
-            self.device
-        )
-        validation_loss = self.hparams.criterion_valid(outputs, labels.float())
+        validation_loss = self.criterion_valid(outputs, labels.float())
 
         if dataloader_idx == 0:
             return {"validation_loss": validation_loss}
@@ -90,8 +165,9 @@
             return {f"extra_validation_loss_{dataloader_idx}": validation_loss}
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False):
-        names = batch[0]
-        images = batch[1]
+        images = batch[0]
+        labels = batch[1]["label"]
+        names = batch[1]["name"]
 
         outputs = self(images)
         probabilities = torch.sigmoid(outputs)
@@ -101,11 +177,8 @@
         if isinstance(outputs, list):
             outputs = outputs[-1]
 
-        return names[0], torch.flatten(probabilities), torch.flatten(batch[2])
+        return names[0], torch.flatten(probabilities), torch.flatten(labels)
 
     def configure_optimizers(self):
-        optimizer = getattr(torch.optim, self.hparams.optimizer)(
-            self.parameters(), **self.hparams.optimizer_configs
-        )
-
+        optimizer = self.optimizer(self.parameters(), **self.optimizer_configs)
         return optimizer
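
Usage note: a minimal sketch of how the reworked API above is meant to be wired together. The config now passes the optimizer class (instantiated later in configure_optimizers()), and the normalizer and BCE positive-class weights are attached after construction via set_normalizer() and set_bce_loss_weights(). The `datamodule` below is a hypothetical stand-in for whatever Lightning datamodule the project provides; all other names come from this diff.

    from torch import empty
    from torch.nn import BCEWithLogitsLoss
    from torch.optim import SGD

    from ptbench.data.transforms import ElasticDeformation
    from ptbench.models.alexnet import Alexnet

    model = Alexnet(
        criterion=BCEWithLogitsLoss(pos_weight=empty(1)),
        criterion_valid=BCEWithLogitsLoss(pos_weight=empty(1)),
        optimizer=SGD,  # the class itself; instantiated in configure_optimizers()
        optimizer_configs={"lr": 0.01, "momentum": 0.1},
        pretrained=False,
        augmentation_transforms=[ElasticDeformation(p=0.8)],
    )

    datamodule = ...  # hypothetical: anything with train_dataloader()/val_dataloader()

    # Compute z-norm factors from the training data (with pretrained=True the
    # preset ImageNet factors are reused instead), then reweight both BCE
    # criteria by the positive-class balance of the respective splits.
    model.set_normalizer(datamodule.train_dataloader())
    model.set_bce_loss_weights(datamodule)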