Skip to content
Snippets Groups Projects
Commit 79187009 authored by Daniel CARRON's avatar Daniel CARRON :b:
Browse files

Update alexnet model

parent dd3c5fba
No related branches found
No related tags found
1 merge request!7Reviewed DataModule design+docs+types
Pipeline #75585 failed
......@@ -6,19 +6,30 @@
from torch import empty
from torch.nn import BCEWithLogitsLoss
from torch.optim import SGD
from ...models.alexnet import Alexnet
# config
# optimizer
optimizer = SGD
optimizer_configs = {"lr": 0.01, "momentum": 0.1}
# optimizer
optimizer = "SGD"
# criterion
criterion = BCEWithLogitsLoss(pos_weight=empty(1))
criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1))
from ...data.transforms import ElasticDeformation
augmentation_transforms = [
ElasticDeformation(p=0.8),
]
# model
model = Alexnet(
criterion, criterion_valid, optimizer, optimizer_configs, pretrained=False
criterion,
criterion_valid,
optimizer,
optimizer_configs,
pretrained=False,
augmentation_transforms=augmentation_transforms,
)
......@@ -6,19 +6,30 @@
from torch import empty
from torch.nn import BCEWithLogitsLoss
from torch.optim import SGD
from ...models.alexnet import Alexnet
# config
optimizer_configs = {"lr": 0.001, "momentum": 0.1}
# optimizer
optimizer = "SGD"
optimizer = SGD
optimizer_configs = {"lr": 0.01, "momentum": 0.1}
# criterion
criterion = BCEWithLogitsLoss(pos_weight=empty(1))
criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1))
from ...data.transforms import ElasticDeformation
augmentation_transforms = [
ElasticDeformation(p=0.8),
]
# model
model = Alexnet(
criterion, criterion_valid, optimizer, optimizer_configs, pretrained=True
criterion,
criterion_valid,
optimizer,
optimizer_configs,
pretrained=True,
augmentation_transforms=augmentation_transforms,
)
......@@ -2,12 +2,15 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later
import logging
import lightning.pytorch as pl
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms
from .normalizer import TorchVisionNormalizer
logger = logging.getLogger(__name__)
class Alexnet(pl.LightningModule):
......@@ -18,25 +21,38 @@ class Alexnet(pl.LightningModule):
def __init__(
self,
criterion,
criterion_valid,
optimizer,
optimizer_configs,
criterion=None,
criterion_valid=None,
optimizer=None,
optimizer_configs=None,
pretrained=False,
augmentation_transforms=[],
):
super().__init__()
self.save_hyperparameters(ignore=["criterion", "criterion_valid"])
self.name = "AlexNet"
# Load pretrained model
weights = (
None if pretrained is False else models.AlexNet_Weights.DEFAULT
self.augmentation_transforms = torchvision.transforms.Compose(
augmentation_transforms
)
self.model_ft = models.alexnet(weights=weights)
self.normalizer = TorchVisionNormalizer(nb_channels=1)
self.criterion = criterion
self.criterion_valid = criterion_valid
self.optimizer = optimizer
self.optimizer_configs = optimizer_configs
self.normalizer = None
self.pretrained = pretrained
# Load pretrained model
if not pretrained:
weights = None
else:
logger.info("Loading pretrained model weights")
weights = models.AlexNet_Weights.DEFAULT
self.model_ft = models.alexnet(weights=weights)
# Adapt output features
self.model_ft.classifier[4] = nn.Linear(4096, 512)
......@@ -48,9 +64,69 @@ class Alexnet(pl.LightningModule):
return x
def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None:
"""Initializes the normalizer for the current model.
This function is NOOP if ``pretrained = True`` (normalizer set to
imagenet weights, during contruction).
Parameters
----------
dataloader: :py:class:`torch.utils.data.DataLoader`
A torch Dataloader from which to compute the mean and std.
Will not be used if the model is pretrained.
"""
if self.pretrained:
from .normalizer import make_imagenet_normalizer
logger.warning(
"ImageNet pre-trained densenet model - NOT "
"computing z-norm factors from training data. "
"Using preset factors from torchvision."
)
self.normalizer = make_imagenet_normalizer()
else:
from .normalizer import make_z_normalizer
logger.info(
"Uninitialised densenet model - "
"computing z-norm factors from training data."
)
self.normalizer = make_z_normalizer(dataloader)
def set_bce_loss_weights(self, datamodule):
"""Reweights loss weights if BCEWithLogitsLoss is used.
Parameters
----------
datamodule:
A datamodule implementing train_dataloader() and val_dataloader()
"""
from ..data.dataset import _get_positive_weights
if isinstance(self.criterion, torch.nn.BCEWithLogitsLoss):
logger.info("Reweighting BCEWithLogitsLoss training criterion.")
train_positive_weights = _get_positive_weights(
datamodule.train_dataloader()
)
self.criterion = torch.nn.BCEWithLogitsLoss(
pos_weight=train_positive_weights
)
if isinstance(self.criterion_valid, torch.nn.BCEWithLogitsLoss):
logger.info("Reweighting BCEWithLogitsLoss validation criterion.")
validation_positive_weights = _get_positive_weights(
datamodule.val_dataloader()["validation"]
)
self.criterion_valid = torch.nn.BCEWithLogitsLoss(
pos_weight=validation_positive_weights
)
def training_step(self, batch, batch_idx):
images = batch[1]
labels = batch[2]
images = batch[0]
labels = batch[1]["label"]
# Increase label dimension if too low
# Allows single and multiclass usage
......@@ -58,17 +134,20 @@ class Alexnet(pl.LightningModule):
labels = torch.reshape(labels, (labels.shape[0], 1))
# Forward pass on the network
outputs = self(images)
augmented_images = [
self.augmentation_transforms(img).to(self.device) for img in images
]
# Combine list of augmented images back into a tensor
augmented_images = torch.cat(augmented_images, 0).view(images.shape)
outputs = self(augmented_images)
# Manually move criterion to selected device, since not part of the model.
self.hparams.criterion = self.hparams.criterion.to(self.device)
training_loss = self.hparams.criterion(outputs, labels.float())
training_loss = self.criterion(outputs, labels.float())
return {"loss": training_loss}
def validation_step(self, batch, batch_idx, dataloader_idx=0):
images = batch[1]
labels = batch[2]
images = batch[0]
labels = batch[1]["label"]
# Increase label dimension if too low
# Allows single and multiclass usage
......@@ -78,11 +157,7 @@ class Alexnet(pl.LightningModule):
# data forwarding on the existing network
outputs = self(images)
# Manually move criterion to selected device, since not part of the model.
self.hparams.criterion_valid = self.hparams.criterion_valid.to(
self.device
)
validation_loss = self.hparams.criterion_valid(outputs, labels.float())
validation_loss = self.criterion_valid(outputs, labels.float())
if dataloader_idx == 0:
return {"validation_loss": validation_loss}
......@@ -90,8 +165,9 @@ class Alexnet(pl.LightningModule):
return {f"extra_validation_loss_{dataloader_idx}": validation_loss}
def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False):
names = batch[0]
images = batch[1]
images = batch[0]
labels = batch[1]["label"]
names = batch[1]["name"]
outputs = self(images)
probabilities = torch.sigmoid(outputs)
......@@ -101,11 +177,8 @@ class Alexnet(pl.LightningModule):
if isinstance(outputs, list):
outputs = outputs[-1]
return names[0], torch.flatten(probabilities), torch.flatten(batch[2])
return names[0], torch.flatten(probabilities), torch.flatten(labels)
def configure_optimizers(self):
optimizer = getattr(torch.optim, self.hparams.optimizer)(
self.parameters(), **self.hparams.optimizer_configs
)
optimizer = self.optimizer(self.parameters(), **self.optimizer_configs)
return optimizer
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment