Commit 79187009 authored by Daniel CARRON

Update alexnet model

parent dd3c5fba
Merge request !7: Reviewed DataModule design+docs+types
Pipeline #75585 failed
Config 1 (Alexnet, trained from scratch):

@@ -6,19 +6,30 @@
 from torch import empty
 from torch.nn import BCEWithLogitsLoss
+from torch.optim import SGD
 
 from ...models.alexnet import Alexnet
 
-# config
+# optimizer
+optimizer = SGD
 optimizer_configs = {"lr": 0.01, "momentum": 0.1}
-
-# optimizer
-optimizer = "SGD"
 
 # criterion
 criterion = BCEWithLogitsLoss(pos_weight=empty(1))
 criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1))
 
+from ...data.transforms import ElasticDeformation
+
+augmentation_transforms = [
+    ElasticDeformation(p=0.8),
+]
+
 # model
 model = Alexnet(
-    criterion, criterion_valid, optimizer, optimizer_configs, pretrained=False
+    criterion,
+    criterion_valid,
+    optimizer,
+    optimizer_configs,
+    pretrained=False,
+    augmentation_transforms=augmentation_transforms,
 )
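Note: the configuration now passes the optimizer class itself (torch.optim.SGD) rather than its name as a string, so the model no longer needs a getattr lookup on torch.optim. A minimal sketch of how these two config values are consumed, mirroring the configure_optimizers() change in the model diff below:

import torch
from torch.optim import SGD

optimizer = SGD
optimizer_configs = {"lr": 0.01, "momentum": 0.1}

# configure_optimizers() calls the class with the model parameters plus the
# keyword arguments from the config:
params = [torch.nn.Parameter(torch.zeros(1))]  # stand-in for self.parameters()
opt = optimizer(params, **optimizer_configs)   # == SGD(params, lr=0.01, momentum=0.1)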
Config 2 (Alexnet, ImageNet-pretrained):

@@ -6,19 +6,30 @@
 from torch import empty
 from torch.nn import BCEWithLogitsLoss
+from torch.optim import SGD
 
 from ...models.alexnet import Alexnet
 
-# config
-optimizer_configs = {"lr": 0.001, "momentum": 0.1}
-
 # optimizer
-optimizer = "SGD"
+optimizer = SGD
+optimizer_configs = {"lr": 0.01, "momentum": 0.1}
 
 # criterion
 criterion = BCEWithLogitsLoss(pos_weight=empty(1))
 criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1))
 
+from ...data.transforms import ElasticDeformation
+
+augmentation_transforms = [
+    ElasticDeformation(p=0.8),
+]
+
 # model
 model = Alexnet(
-    criterion, criterion_valid, optimizer, optimizer_configs, pretrained=True
+    criterion,
+    criterion_valid,
+    optimizer,
+    optimizer_configs,
+    pretrained=True,
+    augmentation_transforms=augmentation_transforms,
 )
Model implementation (models/alexnet.py):

@@ -2,12 +2,15 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
+import logging
+
 import lightning.pytorch as pl
 import torch
 import torch.nn as nn
 import torchvision.models as models
+import torchvision.transforms
 
-from .normalizer import TorchVisionNormalizer
+logger = logging.getLogger(__name__)
 
 
 class Alexnet(pl.LightningModule):
@@ -18,25 +21,38 @@
     def __init__(
         self,
-        criterion,
-        criterion_valid,
-        optimizer,
-        optimizer_configs,
+        criterion=None,
+        criterion_valid=None,
+        optimizer=None,
+        optimizer_configs=None,
         pretrained=False,
+        augmentation_transforms=[],
     ):
         super().__init__()
+        self.save_hyperparameters(ignore=["criterion", "criterion_valid"])
 
         self.name = "AlexNet"
 
-        # Load pretrained model
-        weights = (
-            None if pretrained is False else models.AlexNet_Weights.DEFAULT
-        )
-        self.model_ft = models.alexnet(weights=weights)
-        self.normalizer = TorchVisionNormalizer(nb_channels=1)
+        self.augmentation_transforms = torchvision.transforms.Compose(
+            augmentation_transforms
+        )
+
+        self.criterion = criterion
+        self.criterion_valid = criterion_valid
+        self.optimizer = optimizer
+        self.optimizer_configs = optimizer_configs
+
+        self.normalizer = None
+        self.pretrained = pretrained
+
+        # Load pretrained model
+        if not pretrained:
+            weights = None
+        else:
+            logger.info("Loading pretrained model weights")
+            weights = models.AlexNet_Weights.DEFAULT
+
+        self.model_ft = models.alexnet(weights=weights)
 
         # Adapt output features
         self.model_ft.classifier[4] = nn.Linear(4096, 512)
@@ -48,9 +64,69 @@
         return x
 
+    def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None:
+        """Initializes the normalizer for the current model.
+
+        If ``pretrained = True``, preset ImageNet normalization factors from
+        torchvision are used instead of factors computed from the data.
+
+        Parameters
+        ----------
+
+        dataloader: :py:class:`torch.utils.data.DataLoader`
+            A torch Dataloader from which to compute the mean and std.
+            Will not be used if the model is pretrained.
+        """
+        if self.pretrained:
+            from .normalizer import make_imagenet_normalizer
+
+            logger.warning(
+                "ImageNet pre-trained alexnet model - NOT "
+                "computing z-norm factors from training data. "
+                "Using preset factors from torchvision."
+            )
+            self.normalizer = make_imagenet_normalizer()
+        else:
+            from .normalizer import make_z_normalizer
+
+            logger.info(
+                "Uninitialised alexnet model - "
+                "computing z-norm factors from training data."
+            )
+            self.normalizer = make_z_normalizer(dataloader)
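Note: make_z_normalizer is imported from .normalizer, but its implementation is not part of this diff. A plausible sketch of such a factory, assuming it makes a single pass over the dataloader and returns a torchvision.transforms.Normalize built from per-channel statistics (the signature and return type are assumptions, not confirmed by this commit):

import torch
import torchvision.transforms


def make_z_normalizer(dataloader: torch.utils.data.DataLoader):
    """Hypothetical sketch: per-channel mean/std computed in one pass."""
    total = 0
    sum_ = None
    sum_sq = None
    for batch in dataloader:
        images = batch[0]  # this commit unpacks batches as (images, metadata)
        b, c, h, w = images.shape
        flat = images.transpose(0, 1).reshape(c, -1)  # channel-major view
        sum_ = flat.sum(1) if sum_ is None else sum_ + flat.sum(1)
        sum_sq = (flat**2).sum(1) if sum_sq is None else sum_sq + (flat**2).sum(1)
        total += flat.shape[1]
    mean = sum_ / total
    std = (sum_sq / total - mean**2).sqrt()
    return torchvision.transforms.Normalize(mean.tolist(), std.tolist())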
+    def set_bce_loss_weights(self, datamodule):
+        """Reweights the loss criteria if BCEWithLogitsLoss is used.
+
+        Parameters
+        ----------
+
+        datamodule:
+            A datamodule implementing train_dataloader() and val_dataloader().
+        """
+        from ..data.dataset import _get_positive_weights
+
+        if isinstance(self.criterion, torch.nn.BCEWithLogitsLoss):
+            logger.info("Reweighting BCEWithLogitsLoss training criterion.")
+            train_positive_weights = _get_positive_weights(
+                datamodule.train_dataloader()
+            )
+            self.criterion = torch.nn.BCEWithLogitsLoss(
+                pos_weight=train_positive_weights
+            )
+
+        if isinstance(self.criterion_valid, torch.nn.BCEWithLogitsLoss):
+            logger.info("Reweighting BCEWithLogitsLoss validation criterion.")
+            validation_positive_weights = _get_positive_weights(
+                datamodule.val_dataloader()["validation"]
+            )
+            self.criterion_valid = torch.nn.BCEWithLogitsLoss(
+                pos_weight=validation_positive_weights
+            )
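Note: _get_positive_weights comes from ..data.dataset and is likewise not shown in this diff. Given that its result feeds the pos_weight argument of BCEWithLogitsLoss, a sketch consistent with that usage could look as follows (an assumed implementation for illustration, not the project's actual helper):

import torch


def _get_positive_weights(dataloader):
    """Hypothetical sketch: per-class negative/positive count ratio."""
    labels = torch.cat(
        [batch[1]["label"].reshape(len(batch[1]["label"]), -1) for batch in dataloader]
    ).float()
    positives = labels.sum(0)                # per-class positive counts
    negatives = labels.shape[0] - positives  # per-class negative counts
    # BCEWithLogitsLoss multiplies the positive term by pos_weight, so
    # n_negative / n_positive counteracts class imbalance.
    return negatives / positives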
     def training_step(self, batch, batch_idx):
-        images = batch[1]
-        labels = batch[2]
+        images = batch[0]
+        labels = batch[1]["label"]
 
         # Increase label dimension if too low
         # Allows single and multiclass usage
@@ -58,17 +134,20 @@
         labels = torch.reshape(labels, (labels.shape[0], 1))
 
         # Forward pass on the network
-        outputs = self(images)
+        augmented_images = [
+            self.augmentation_transforms(img).to(self.device) for img in images
+        ]
+        # Combine list of augmented images back into a tensor
+        augmented_images = torch.cat(augmented_images, 0).view(images.shape)
+        outputs = self(augmented_images)
 
-        # Manually move criterion to selected device, since not part of the model.
-        self.hparams.criterion = self.hparams.criterion.to(self.device)
-
-        training_loss = self.hparams.criterion(outputs, labels.float())
+        training_loss = self.criterion(outputs, labels.float())
 
         return {"loss": training_loss}
 
     def validation_step(self, batch, batch_idx, dataloader_idx=0):
-        images = batch[1]
-        labels = batch[2]
+        images = batch[0]
+        labels = batch[1]["label"]
 
         # Increase label dimension if too low
         # Allows single and multiclass usage
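Note on the batch reassembly in training_step above: each augmented image keeps its C x H x W shape, so torch.cat followed by view rebuilds the original N x C x H x W batch; torch.stack expresses the same reassembly directly. A minimal equivalence check:

import torch

images = torch.rand(4, 1, 8, 8)                       # dummy N x C x H x W batch
augmented = [img.clone() for img in images]           # stand-in for the transform loop
via_cat = torch.cat(augmented, 0).view(images.shape)  # approach used in the diff
via_stack = torch.stack(augmented)                    # equivalent, shape-preserving
assert torch.equal(via_cat, via_stack)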
@@ -78,11 +157,7 @@
         # data forwarding on the existing network
         outputs = self(images)
 
-        # Manually move criterion to selected device, since not part of the model.
-        self.hparams.criterion_valid = self.hparams.criterion_valid.to(
-            self.device
-        )
-
-        validation_loss = self.hparams.criterion_valid(outputs, labels.float())
+        validation_loss = self.criterion_valid(outputs, labels.float())
 
         if dataloader_idx == 0:
             return {"validation_loss": validation_loss}
@@ -90,8 +165,9 @@
             return {f"extra_validation_loss_{dataloader_idx}": validation_loss}
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False):
-        names = batch[0]
-        images = batch[1]
+        images = batch[0]
+        labels = batch[1]["label"]
+        names = batch[1]["name"]
 
         outputs = self(images)
         probabilities = torch.sigmoid(outputs)
@@ -101,11 +177,8 @@
         if isinstance(outputs, list):
             outputs = outputs[-1]
 
-        return names[0], torch.flatten(probabilities), torch.flatten(batch[2])
+        return names[0], torch.flatten(probabilities), torch.flatten(labels)
 
     def configure_optimizers(self):
-        optimizer = getattr(torch.optim, self.hparams.optimizer)(
-            self.parameters(), **self.hparams.optimizer_configs
-        )
+        optimizer = self.optimizer(self.parameters(), **self.optimizer_configs)
         return optimizer
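Putting the pieces together, a configuration in the new style plus the helpers added in this commit would be used roughly as follows (a hedged sketch: the helper call order and the datamodule shape are assumptions based only on what this diff shows):

from torch import empty
from torch.nn import BCEWithLogitsLoss
from torch.optim import SGD

from ...models.alexnet import Alexnet  # package-relative, as in the configs above

model = Alexnet(
    criterion=BCEWithLogitsLoss(pos_weight=empty(1)),
    criterion_valid=BCEWithLogitsLoss(pos_weight=empty(1)),
    optimizer=SGD,
    optimizer_configs={"lr": 0.01, "momentum": 0.1},
    pretrained=False,
    augmentation_transforms=[],
)

# Before fitting, the new helpers would be called with a datamodule exposing
# train_dataloader() / val_dataloader() (assumed usage, not shown in the diff):
# model.set_normalizer(datamodule.train_dataloader())
# model.set_bce_loss_weights(datamodule)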