Skip to content
Snippets Groups Projects
Commit 2bd2a302 authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[models] Remove loss-weight-balancing pending on issue #6 resolution

parent b2ff0f0d
No related branches found
Tags v5.0.0
1 merge request!6Making use of LightningDataModule and simplification of data loading
......@@ -13,7 +13,7 @@ import torch.utils.data
import torchvision.models as models
import torchvision.transforms
from ..data.typing import DataLoader, TransformSequence
from ..data.typing import TransformSequence
from .typing import Checkpoint
logger = logging.getLogger(__name__)
......@@ -56,7 +56,8 @@ class Alexnet(pl.LightningModule):
applied on the input **before** it is fed into the network.
pretrained
If set to True, loads pretrained model weights during initialization, else trains a new model.
If set to True, loads pretrained model weights during initialization,
else trains a new model.
"""
def __init__(
......@@ -108,7 +109,8 @@ class Alexnet(pl.LightningModule):
def on_save_checkpoint(self, checkpoint: Checkpoint) -> None:
"""Called by Lightning to restore your model.
If you saved something with on_save_checkpoint() this is your chance to restore this.
If you saved something with on_save_checkpoint() this is your chance to
restore this.
Parameters
----------
......@@ -162,48 +164,6 @@ class Alexnet(pl.LightningModule):
)
self.normalizer = make_z_normalizer(dataloader)
def balance_losses_by_class(
self, train_dataloader: DataLoader, valid_dataloader: DataLoader
):
"""Reweights loss weights if possible.
Parameters
----------
train_dataloader
The data loader to use for training
valid_dataloader
The data loader to use for validation
Raises
------
RuntimeError
If train or validation losses are not of type
:py:class:`torch.nn.BCEWithLogitsLoss`.
"""
from .loss_weights import get_label_weights
if isinstance(self._train_loss, torch.nn.BCEWithLogitsLoss):
logger.info("Reweighting BCEWithLogitsLoss training loss.")
weights = get_label_weights(train_dataloader)
self._train_loss = torch.nn.BCEWithLogitsLoss(pos_weight=weights)
else:
raise RuntimeError(
"Training loss is not BCEWithLogitsLoss - dunno how to balance"
)
if isinstance(self._validation_loss, torch.nn.BCEWithLogitsLoss):
logger.info("Reweighting BCEWithLogitsLoss validation loss.")
weights = get_label_weights(valid_dataloader)
self._validation_loss = torch.nn.BCEWithLogitsLoss(weights)
else:
raise RuntimeError(
"Validation loss is not BCEWithLogitsLoss - dunno how to balance"
)
def training_step(self, batch, _):
images = batch[0]
labels = batch[1]["label"]
......
......@@ -13,7 +13,7 @@ import torch.utils.data
import torchvision.models as models
import torchvision.transforms
from ..data.typing import DataLoader, TransformSequence
from ..data.typing import TransformSequence
from .typing import Checkpoint
logger = logging.getLogger(__name__)
......@@ -54,7 +54,8 @@ class Densenet(pl.LightningModule):
applied on the input **before** it is fed into the network.
pretrained
If set to True, loads pretrained model weights during initialization, else trains a new model.
If set to True, loads pretrained model weights during initialization,
else trains a new model.
"""
def __init__(
......@@ -107,7 +108,8 @@ class Densenet(pl.LightningModule):
def on_save_checkpoint(self, checkpoint: Checkpoint) -> None:
"""Called by Lightning to restore your model.
If you saved something with on_save_checkpoint() this is your chance to restore this.
If you saved something with on_save_checkpoint() this is your chance to
restore this.
Parameters
----------
......@@ -161,50 +163,6 @@ class Densenet(pl.LightningModule):
)
self.normalizer = make_z_normalizer(dataloader)
def balance_losses_by_class(
self,
train_dataloader: DataLoader,
valid_dataloader: dict[str, DataLoader],
):
"""Reweights loss weights if possible.
Parameters
----------
train_dataloader
The data loader to use for training
valid_dataloader
The data loaders to use for each of the validation sets
Raises
------
RuntimeError
If train or validation losses are not of type
:py:class:`torch.nn.BCEWithLogitsLoss`.
"""
from .loss_weights import get_label_weights
if isinstance(self._train_loss, torch.nn.BCEWithLogitsLoss):
logger.info("Reweighting BCEWithLogitsLoss training loss.")
weights = get_label_weights(train_dataloader)
self._train_loss = torch.nn.BCEWithLogitsLoss(pos_weight=weights)
else:
raise RuntimeError(
"Training loss is not BCEWithLogitsLoss - dunno how to balance"
)
if isinstance(self._validation_loss, torch.nn.BCEWithLogitsLoss):
logger.info("Reweighting BCEWithLogitsLoss validation loss.")
weights = get_label_weights(valid_dataloader)
self._validation_loss = torch.nn.BCEWithLogitsLoss(weights)
else:
raise RuntimeError(
"Validation loss is not BCEWithLogitsLoss - dunno how to balance"
)
def training_step(self, batch, _):
images = batch[0]
labels = batch[1]["label"]
......
......@@ -7,10 +7,12 @@ import logging
import torch
import torch.utils.data
from ..data.typing import DataLoader
logger = logging.getLogger(__name__)
def get_label_weights(
def _get_label_weights(
dataloader: torch.utils.data.DataLoader,
) -> torch.Tensor:
"""Computes the weights of each class of a DataLoader.
......@@ -68,3 +70,23 @@ def get_label_weights(
)
return positive_weights
def make_balanced_bcewithlogitsloss(
dataloader: DataLoader,
) -> torch.nn.BCEWithLogitsLoss:
"""Returns a balanced binary-cross-entropy loss.
The loss is weighted using the ratio between positives and total examples
available.
Returns
-------
loss
An instance of the weighted loss
"""
weights = _get_label_weights(dataloader)
return torch.nn.BCEWithLogitsLoss(pos_weight=weights)
......@@ -13,7 +13,7 @@ import torch.optim.optimizer
import torch.utils.data
import torchvision.transforms
from ..data.typing import DataLoader, TransformSequence
from ..data.typing import TransformSequence
from .typing import Checkpoint
logger = logging.getLogger(__name__)
......@@ -192,7 +192,8 @@ class Pasa(pl.LightningModule):
def on_save_checkpoint(self, checkpoint: Checkpoint) -> None:
"""Called by Lightning to restore your model.
If you saved something with on_save_checkpoint() this is your chance to restore this.
If you saved something with on_save_checkpoint() this is your chance to
restore this.
Parameters
----------
......@@ -232,50 +233,6 @@ class Pasa(pl.LightningModule):
)
self.normalizer = make_z_normalizer(dataloader)
def balance_losses_by_class(
self,
train_dataloader: DataLoader,
valid_dataloader: dict[str, DataLoader],
):
"""Reweights loss weights if possible.
Parameters
----------
train_dataloader
The data loader to use for training
valid_dataloader
The data loaders to use for each of the validation sets
Raises
------
RuntimeError
If train or validation losses are not of type
:py:class:`torch.nn.BCEWithLogitsLoss`.
"""
from .loss_weights import get_label_weights
if isinstance(self._train_loss, torch.nn.BCEWithLogitsLoss):
logger.info("Reweighting BCEWithLogitsLoss training loss.")
weights = get_label_weights(train_dataloader)
self._train_loss = torch.nn.BCEWithLogitsLoss(pos_weight=weights)
else:
raise RuntimeError(
"Training loss is not BCEWithLogitsLoss - dunno how to balance"
)
if isinstance(self._validation_loss, torch.nn.BCEWithLogitsLoss):
logger.info("Reweighting BCEWithLogitsLoss validation loss.")
weights = get_label_weights(valid_dataloader)
self._validation_loss = torch.nn.BCEWithLogitsLoss(weights)
else:
raise RuntimeError(
"Validation loss is not BCEWithLogitsLoss - dunno how to balance"
)
def training_step(self, batch, _):
images = batch[0]
labels = batch[1]["label"]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment