Commit 7eaac22c authored by Daniel CARRON

Updated models

parent 7b12973c
Merge request !7: Reviewed DataModule design+docs+types
Pipeline #75758 failed
@@ -229,11 +229,11 @@ def _make_balanced_random_sampler(
     1. The probability of picking a sample from any target is the same (0.5 in
        this case). To verify this, notice that the probability of picking a
        sample with ``target=0`` is :math:`1/4 x 1 + 1/12 x 3 = 0.5`.
-    2. The probabiility of picking a sample with ``target=0`` from Dataset 2 is
+    2. The probability of picking a sample with ``target=0`` from Dataset 2 is
        3 times higher than those from Dataset 1. As there are 3 times less
        samples in Dataset 2 with ``target=0``, this makes choosing samples from
        Dataset 1 proportionally less likely.
-    3. The probabiility of picking a sample with ``target=1`` from Dataset 2 is
+    3. The probability of picking a sample with ``target=1`` from Dataset 2 is
        3 times lower than those from Dataset 1. As there are 3 times less
        samples in Dataset 1 with ``target=1``, this makes choosing samples from
        Dataset 2 proportionally less likely.
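[Editor's note] For intuition, a minimal standalone sketch with hypothetical numbers matching the example above (Dataset 1 holds three ``target=0`` samples and one ``target=1`` sample; Dataset 2 is the mirror image), showing how such per-sample weights drive ``torch.utils.data.WeightedRandomSampler``::

    import torch
    from torch.utils.data import WeightedRandomSampler

    # weight = 1/2 (per target) * 1/2 (per dataset) * 1/n_samples_in_group
    weights = torch.tensor(
        [1 / 12, 1 / 12, 1 / 12, 1 / 4,   # Dataset 1: 3x target=0, 1x target=1
         1 / 4, 1 / 12, 1 / 12, 1 / 12],  # Dataset 2: 1x target=0, 3x target=1
    )
    targets = torch.tensor([0, 0, 0, 1, 0, 1, 1, 1])
    sampler = WeightedRandomSampler(weights, num_samples=10_000, replacement=True)

    # Both targets are drawn with probability ~0.5, as claimed above:
    print(targets[list(sampler)].float().mean())  # ~0.5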
@@ -3,14 +3,18 @@
 # SPDX-License-Identifier: GPL-3.0-or-later

 import logging
+import typing

 import lightning.pytorch as pl
-import torch.nn as nn
 import torch
+import torch.nn
 import torch.nn.functional as F
+import torch.optim.optimizer
 import torch.utils.data
 import torchvision.models as models
 import torchvision.transforms

-from ..data.typing import DataLoader
+from ..data.typing import DataLoader, TransformSequence

 logger = logging.getLogger(__name__)
@@ -19,32 +23,66 @@ class Alexnet(pl.LightningModule):
     """Alexnet module.

     Note: only usable with a normalized dataset

     Parameters
     ----------

+    train_loss
+        The loss to be used during the training.
+
+        .. warning::
+
+           The loss should be set to always return batch averages (as opposed
+           to the batch sum), as our logging system expects it so.
+
+    validation_loss
+        The loss to be used for validation (may be different from the training
+        loss). If extra-validation sets are provided, the same loss will be
+        used throughout.
+
+        .. warning::
+
+           The loss should be set to always return batch averages (as opposed
+           to the batch sum), as our logging system expects it so.
+
+    optimizer_type
+        The type of optimizer to use for training.
+
+    optimizer_arguments
+        Arguments to the optimizer after ``params``.
+
+    augmentation_transforms
+        An optional sequence of torch modules containing transforms to be
+        applied on the input **before** it is fed into the network.
+
+    pretrained
+        If set to True, loads pretrained model weights during initialization,
+        else trains a new model.
     """

     def __init__(
         self,
-        criterion=None,
-        criterion_valid=None,
-        optimizer=None,
-        optimizer_configs=None,
-        pretrained=False,
-        augmentation_transforms=[],
+        train_loss: torch.nn.Module,
+        validation_loss: torch.nn.Module | None,
+        optimizer_type: type[torch.optim.Optimizer],
+        optimizer_arguments: dict[str, typing.Any],
+        augmentation_transforms: TransformSequence = [],
+        pretrained: bool = False,
     ):
         super().__init__()

         self.name = "alexnet"

-        self.augmentation_transforms = torchvision.transforms.Compose(
-            augmentation_transforms
+        self._train_loss = train_loss
+        self._validation_loss = (
+            validation_loss if validation_loss is not None else train_loss
         )
+        self._optimizer_type = optimizer_type
+        self._optimizer_arguments = optimizer_arguments

-        self.criterion = criterion
-        self.criterion_valid = criterion_valid
-        self.optimizer = optimizer
-        self.optimizer_configs = optimizer_configs
+        self._augmentation_transforms = torchvision.transforms.Compose(
+            augmentation_transforms
+        )

         self.normalizer = None
         self.pretrained = pretrained

         # Load pretrained model
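[Editor's note] With the new signature, construction might look like the following minimal sketch; the import path is hypothetical, since the package root is not visible in this diff::

    import torch
    from mypackage.models.alexnet import Alexnet  # hypothetical import path

    model = Alexnet(
        train_loss=torch.nn.BCEWithLogitsLoss(),
        validation_loss=None,  # None falls back to the training loss
        optimizer_type=torch.optim.Adam,
        optimizer_arguments={"lr": 1e-4},
        augmentation_transforms=[],  # TransformSequence; empty = no augmentation
        pretrained=False,
    )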
@@ -57,11 +95,12 @@ class Alexnet(pl.LightningModule):
         self.model_ft = models.alexnet(weights=weights)

         # Adapt output features
-        self.model_ft.classifier[4] = nn.Linear(4096, 512)
-        self.model_ft.classifier[6] = nn.Linear(512, 1)
+        self.model_ft.classifier[4] = torch.nn.Linear(4096, 512)
+        self.model_ft.classifier[6] = torch.nn.Linear(512, 1)

     def forward(self, x):
-        x = self.normalizer(x)
+        x = self.normalizer(x)  # type: ignore
         x = self.model_ft(x)
         return x
@@ -121,25 +160,25 @@ class Alexnet(pl.LightningModule):
         """
         from .loss_weights import get_label_weights

-        if isinstance(self.criterion, torch.nn.BCEWithLogitsLoss):
-            logger.info("Reweighting BCEWithLogitsLoss training criterion.")
+        if isinstance(self._train_loss, torch.nn.BCEWithLogitsLoss):
+            logger.info("Reweighting BCEWithLogitsLoss training loss.")
             weights = get_label_weights(train_dataloader)
-            self.criterion = torch.nn.BCEWithLogitsLoss(pos_weight=weights)
+            self._train_loss = torch.nn.BCEWithLogitsLoss(pos_weight=weights)
         else:
             raise RuntimeError(
                 "Training loss is not BCEWithLogitsLoss - dunno how to balance"
             )

-        if isinstance(self.criterion_valid, torch.nn.BCEWithLogitsLoss):
-            logger.info("Reweighting BCEWithLogitsLoss validation criterion.")
+        if isinstance(self._validation_loss, torch.nn.BCEWithLogitsLoss):
+            logger.info("Reweighting BCEWithLogitsLoss validation loss.")
             weights = get_label_weights(valid_dataloader)
-            self.criterion_valid = torch.nn.BCEWithLogitsLoss(weights)
+            self._validation_loss = torch.nn.BCEWithLogitsLoss(weights)
         else:
             raise RuntimeError(
                 "Validation loss is not BCEWithLogitsLoss - dunno how to balance"
             )

-    def training_step(self, batch, batch_idx):
+    def training_step(self, batch, _):
         images = batch[0]
         labels = batch[1]["label"]
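[Editor's note] The effect of ``pos_weight`` can be seen in isolation with this toy sketch in plain PyTorch; it does not use the project's ``get_label_weights``, whose output format is an assumption here::

    import torch

    # With pos_weight = n_negatives / n_positives, the rarer positive
    # class contributes proportionally more to the loss:
    labels = torch.tensor([0.0, 0.0, 0.0, 1.0])  # 3 negatives, 1 positive
    pos_weight = (labels == 0).sum() / (labels == 1).sum()  # tensor(3.)

    plain = torch.nn.BCEWithLogitsLoss()
    weighted = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    logits = torch.zeros(4)
    print(plain(logits, labels), weighted(logits, labels))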
@@ -150,15 +189,13 @@ class Alexnet(pl.LightningModule):
         # Forward pass on the network
         augmented_images = [
-            self.augmentation_transforms(img).to(self.device) for img in images
+            self._augmentation_transforms(img).to(self.device) for img in images
         ]
         # Combine list of augmented images back into a tensor
         augmented_images = torch.cat(augmented_images, 0).view(images.shape)

         outputs = self(augmented_images)

-        training_loss = self.criterion(outputs, labels.float())
-
-        return {"loss": training_loss}
+        return self._train_loss(outputs, labels.float())

     def validation_step(self, batch, batch_idx, dataloader_idx=0):
         images = batch[0]
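[Editor's note] The per-image augmentation loop above can be exercised standalone; a toy sketch with made-up batch dimensions and a stand-in transform::

    import torch
    import torchvision.transforms

    transforms = torchvision.transforms.Compose(
        [torchvision.transforms.RandomHorizontalFlip(p=1.0)]
    )
    images = torch.rand(4, 1, 8, 8)  # hypothetical batch of grayscale images

    # Transform image by image, then re-stack into the original batch shape:
    augmented = [transforms(img) for img in images]
    augmented = torch.cat(augmented, 0).view(images.shape)
    assert augmented.shape == images.shape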
@@ -172,23 +209,23 @@ class Alexnet(pl.LightningModule):
         # data forwarding on the existing network
         outputs = self(images)

-        validation_loss = self.criterion_valid(outputs, labels.float())
-
-        if dataloader_idx == 0:
-            return {"validation_loss": validation_loss}
-        else:
-            return {f"extra_validation_loss_{dataloader_idx}": validation_loss}
+        return self._validation_loss(outputs, labels.float())

     def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False):
         images = batch[0]
         labels = batch[1]["label"]
-        names = batch[1]["name"]
+        names = batch[1]["names"]

         outputs = self(images)
         probabilities = torch.sigmoid(outputs)
-        return names[0], torch.flatten(probabilities), torch.flatten(labels)
+        return (
+            names[0],
+            torch.flatten(probabilities),
+            torch.flatten(labels),
+        )

     def configure_optimizers(self):
-        optimizer = self.optimizer(self.parameters(), **self.optimizer_configs)
-        return optimizer
+        return self._optimizer_type(
+            self.parameters(), **self._optimizer_arguments
+        )
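[Editor's note] Since ``classifier[6]`` emits a single logit per image, ``predict_step`` maps it to a probability with a sigmoid. For instance::

    import torch

    logits = torch.tensor([[-2.0], [0.0], [3.0]])  # one logit per image
    print(torch.flatten(torch.sigmoid(logits)))
    # tensor([0.1192, 0.5000, 0.9526])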
@@ -3,15 +3,18 @@
 # SPDX-License-Identifier: GPL-3.0-or-later

 import logging
+import typing

 import lightning.pytorch as pl
 import torch
+import torch.nn
 import torch.nn.functional as F
+import torch.optim.optimizer
 import torch.utils.data
 import torchvision.models as models
 import torchvision.transforms

-from ..data.typing import DataLoader
+from ..data.typing import DataLoader, TransformSequence

 logger = logging.getLogger(__name__)
@@ -22,32 +25,61 @@ class Densenet(pl.LightningModule):
     Parameters
     ----------

-    criterion
-        A dictionary containing the criteria for the
+    train_loss
+        The loss to be used during the training.
+
+        .. warning::
+
+           The loss should be set to always return batch averages (as opposed
+           to the batch sum), as our logging system expects it so.
+
+    validation_loss
+        The loss to be used for validation (may be different from the training
+        loss). If extra-validation sets are provided, the same loss will be
+        used throughout.
+
+        .. warning::
+
+           The loss should be set to always return batch averages (as opposed
+           to the batch sum), as our logging system expects it so.
+
+    optimizer_type
+        The type of optimizer to use for training.
+
+    optimizer_arguments
+        Arguments to the optimizer after ``params``.
+
+    augmentation_transforms
+        An optional sequence of torch modules containing transforms to be
+        applied on the input **before** it is fed into the network.
+
+    pretrained
+        If set to True, loads pretrained model weights during initialization,
+        else trains a new model.
     """

     def __init__(
         self,
-        criterion=None,
-        criterion_valid=None,
-        optimizer=None,
-        optimizer_configs=None,
-        pretrained=False,
-        augmentation_transforms=[],
+        train_loss: torch.nn.Module,
+        validation_loss: torch.nn.Module | None,
+        optimizer_type: type[torch.optim.Optimizer],
+        optimizer_arguments: dict[str, typing.Any],
+        augmentation_transforms: TransformSequence = [],
+        pretrained: bool = False,
    ):
         super().__init__()

         self.name = "densenet-121"

-        self.augmentation_transforms = torchvision.transforms.Compose(
-            augmentation_transforms
+        self._train_loss = train_loss
+        self._validation_loss = (
+            validation_loss if validation_loss is not None else train_loss
         )
+        self._optimizer_type = optimizer_type
+        self._optimizer_arguments = optimizer_arguments

-        self.criterion = criterion
-        self.criterion_valid = criterion_valid
-        self.optimizer = optimizer
-        self.optimizer_configs = optimizer_configs
+        self._augmentation_transforms = torchvision.transforms.Compose(
+            augmentation_transforms
+        )

         self.pretrained = pretrained
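[Editor's note] Note the ``validation_loss`` fallback shared by both models: passing ``None`` silently reuses the training loss. The same expression, in isolation::

    import torch

    train_loss = torch.nn.BCEWithLogitsLoss()
    validation_loss = None

    # Same fallback expression as in __init__ above:
    effective = validation_loss if validation_loss is not None else train_loss
    assert effective is train_loss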
@@ -66,7 +98,9 @@ class Densenet(pl.LightningModule):
         )

     def forward(self, x):
-        x = self.normalizer(x)
+        x = self.normalizer(x)  # type: ignore
         x = self.model_ft(x)
         return x
@@ -128,25 +162,25 @@ class Densenet(pl.LightningModule):
         """
         from .loss_weights import get_label_weights

-        if isinstance(self.criterion, torch.nn.BCEWithLogitsLoss):
-            logger.info("Reweighting BCEWithLogitsLoss training criterion.")
+        if isinstance(self._train_loss, torch.nn.BCEWithLogitsLoss):
+            logger.info("Reweighting BCEWithLogitsLoss training loss.")
             weights = get_label_weights(train_dataloader)
-            self.criterion = torch.nn.BCEWithLogitsLoss(pos_weight=weights)
+            self._train_loss = torch.nn.BCEWithLogitsLoss(pos_weight=weights)
         else:
             raise RuntimeError(
                 "Training loss is not BCEWithLogitsLoss - dunno how to balance"
             )

-        if isinstance(self.criterion_valid, torch.nn.BCEWithLogitsLoss):
-            logger.info("Reweighting BCEWithLogitsLoss validation criterion.")
+        if isinstance(self._validation_loss, torch.nn.BCEWithLogitsLoss):
+            logger.info("Reweighting BCEWithLogitsLoss validation loss.")
             weights = get_label_weights(valid_dataloader)
-            self.criterion_valid = torch.nn.BCEWithLogitsLoss(weights)
+            self._validation_loss = torch.nn.BCEWithLogitsLoss(weights)
         else:
             raise RuntimeError(
                 "Validation loss is not BCEWithLogitsLoss - dunno how to balance"
             )

-    def training_step(self, batch, batch_idx):
+    def training_step(self, batch, _):
         images = batch[0]
         labels = batch[1]["label"]
@@ -157,15 +191,13 @@ class Densenet(pl.LightningModule):
         # Forward pass on the network
         augmented_images = [
-            self.augmentation_transforms(img).to(self.device) for img in images
+            self._augmentation_transforms(img).to(self.device) for img in images
         ]
         # Combine list of augmented images back into a tensor
         augmented_images = torch.cat(augmented_images, 0).view(images.shape)

         outputs = self(augmented_images)

-        training_loss = self.criterion(outputs, labels.float())
-
-        return {"loss": training_loss}
+        return self._train_loss(outputs, labels.float())

     def validation_step(self, batch, batch_idx, dataloader_idx=0):
         images = batch[0]
@@ -179,23 +211,23 @@ class Densenet(pl.LightningModule):
         # data forwarding on the existing network
         outputs = self(images)

-        validation_loss = self.criterion_valid(outputs, labels.float())
-
-        if dataloader_idx == 0:
-            return {"validation_loss": validation_loss}
-        else:
-            return {f"extra_validation_loss_{dataloader_idx}": validation_loss}
+        return self._validation_loss(outputs, labels.float())

     def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False):
         images = batch[0]
         labels = batch[1]["label"]
-        names = batch[1]["name"]
+        names = batch[1]["names"]

         outputs = self(images)
         probabilities = torch.sigmoid(outputs)
-        return names[0], torch.flatten(probabilities), torch.flatten(labels)
+        return (
+            names[0],
+            torch.flatten(probabilities),
+            torch.flatten(labels),
+        )

     def configure_optimizers(self):
-        optimizer = self.optimizer(self.parameters(), **self.optimizer_configs)
-        return optimizer
+        return self._optimizer_type(
+            self.parameters(), **self._optimizer_arguments
+        )
@@ -13,7 +13,7 @@ import torch.optim.optimizer
 import torch.utils.data
 import torchvision.transforms

-from ..data.typing import TransformSequence
+from ..data.typing import DataLoader, TransformSequence

 logger = logging.getLogger(__name__)
@@ -202,6 +202,50 @@ class Pasa(pl.LightningModule):
         )
         self.normalizer = make_z_normalizer(dataloader)

+    def balance_losses_by_class(
+        self,
+        train_dataloader: DataLoader,
+        valid_dataloader: dict[str, DataLoader],
+    ):
+        """Reweights loss weights if possible.
+
+        Parameters
+        ----------
+
+        train_dataloader
+            The data loader to use for training.
+
+        valid_dataloader
+            The data loaders to use for each of the validation sets.
+
+        Raises
+        ------
+
+        RuntimeError
+            If train or validation losses are not of type
+            :py:class:`torch.nn.BCEWithLogitsLoss`.
+        """
+        from .loss_weights import get_label_weights
+
+        if isinstance(self._train_loss, torch.nn.BCEWithLogitsLoss):
+            logger.info("Reweighting BCEWithLogitsLoss training loss.")
+            weights = get_label_weights(train_dataloader)
+            self._train_loss = torch.nn.BCEWithLogitsLoss(pos_weight=weights)
+        else:
+            raise RuntimeError(
+                "Training loss is not BCEWithLogitsLoss - dunno how to balance"
+            )
+
+        if isinstance(self._validation_loss, torch.nn.BCEWithLogitsLoss):
+            logger.info("Reweighting BCEWithLogitsLoss validation loss.")
+            weights = get_label_weights(valid_dataloader)
+            self._validation_loss = torch.nn.BCEWithLogitsLoss(weights)
+        else:
+            raise RuntimeError(
+                "Validation loss is not BCEWithLogitsLoss - dunno how to balance"
+            )
+
     def training_step(self, batch, _):
         images = batch[0]
         labels = batch[1]["label"]