Commit 2bd2a302 authored by André Anjos

[models] Remove loss-weight-balancing pending on issue #6 resolution

parent b2ff0f0d
1 merge request: !6 Making use of LightningDataModule and simplification of data loading
@@ -13,7 +13,7 @@ import torch.utils.data
 import torchvision.models as models
 import torchvision.transforms
-from ..data.typing import DataLoader, TransformSequence
+from ..data.typing import TransformSequence
 from .typing import Checkpoint
 logger = logging.getLogger(__name__)
@@ -56,7 +56,8 @@ class Alexnet(pl.LightningModule):
     applied on the input **before** it is fed into the network.

     pretrained
-        If set to True, loads pretrained model weights during initialization, else trains a new model.
+        If set to True, loads pretrained model weights during initialization,
+        else trains a new model.
     """

     def __init__(
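The `pretrained` flag documented above maps onto torchvision's weight-loading API. A minimal sketch of how such a constructor could honour it, assuming torchvision >= 0.13 (illustrative only, not code from this repository):

    import torchvision.models as models

    def build_backbone(pretrained: bool):
        # hypothetical helper: select packaged ImageNet weights when requested
        weights = models.AlexNet_Weights.DEFAULT if pretrained else None
        return models.alexnet(weights=weights)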
@@ -108,7 +109,8 @@ class Alexnet(pl.LightningModule):
     def on_save_checkpoint(self, checkpoint: Checkpoint) -> None:
         """Called by Lightning to restore your model.

-        If you saved something with on_save_checkpoint() this is your chance to restore this.
+        If you saved something with on_save_checkpoint() this is your chance to
+        restore this.

         Parameters
         ----------
@@ -162,48 +164,6 @@ class Alexnet(pl.LightningModule):
         )
         self.normalizer = make_z_normalizer(dataloader)
-    def balance_losses_by_class(
-        self, train_dataloader: DataLoader, valid_dataloader: DataLoader
-    ):
-        """Reweights loss weights if possible.
-
-        Parameters
-        ----------
-
-        train_dataloader
-            The data loader to use for training
-
-        valid_dataloader
-            The data loader to use for validation
-
-        Raises
-        ------
-
-        RuntimeError
-            If train or validation losses are not of type
-            :py:class:`torch.nn.BCEWithLogitsLoss`.
-        """
-
-        from .loss_weights import get_label_weights
-
-        if isinstance(self._train_loss, torch.nn.BCEWithLogitsLoss):
-            logger.info("Reweighting BCEWithLogitsLoss training loss.")
-            weights = get_label_weights(train_dataloader)
-            self._train_loss = torch.nn.BCEWithLogitsLoss(pos_weight=weights)
-        else:
-            raise RuntimeError(
-                "Training loss is not BCEWithLogitsLoss - dunno how to balance"
-            )
-
-        if isinstance(self._validation_loss, torch.nn.BCEWithLogitsLoss):
-            logger.info("Reweighting BCEWithLogitsLoss validation loss.")
-            weights = get_label_weights(valid_dataloader)
-            self._validation_loss = torch.nn.BCEWithLogitsLoss(weights)
-        else:
-            raise RuntimeError(
-                "Validation loss is not BCEWithLogitsLoss - dunno how to balance"
-            )
     def training_step(self, batch, _):
         images = batch[0]
         labels = batch[1]["label"]
...
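With `balance_losses_by_class()` gone, a caller wanting the same balancing could build the loss through the helper this commit adds to `loss_weights.py` (third file in this diff). A hedged sketch of such a call site; the `_train_loss`/`_validation_loss` attribute names come from the removed method, everything else is assumed:

    from .loss_weights import make_balanced_bcewithlogitsloss

    # hypothetical setup code, executed before training starts
    model._train_loss = make_balanced_bcewithlogitsloss(train_dataloader)
    model._validation_loss = make_balanced_bcewithlogitsloss(valid_dataloader)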
@@ -13,7 +13,7 @@ import torch.utils.data
 import torchvision.models as models
 import torchvision.transforms
-from ..data.typing import DataLoader, TransformSequence
+from ..data.typing import TransformSequence
 from .typing import Checkpoint
 logger = logging.getLogger(__name__)
@@ -54,7 +54,8 @@ class Densenet(pl.LightningModule):
     applied on the input **before** it is fed into the network.

     pretrained
-        If set to True, loads pretrained model weights during initialization, else trains a new model.
+        If set to True, loads pretrained model weights during initialization,
+        else trains a new model.
     """

     def __init__(
@@ -107,7 +108,8 @@ class Densenet(pl.LightningModule):
     def on_save_checkpoint(self, checkpoint: Checkpoint) -> None:
         """Called by Lightning to restore your model.

-        If you saved something with on_save_checkpoint() this is your chance to restore this.
+        If you saved something with on_save_checkpoint() this is your chance to
+        restore this.

         Parameters
         ----------
@@ -161,50 +163,6 @@ class Densenet(pl.LightningModule):
         )
         self.normalizer = make_z_normalizer(dataloader)
-    def balance_losses_by_class(
-        self,
-        train_dataloader: DataLoader,
-        valid_dataloader: dict[str, DataLoader],
-    ):
-        """Reweights loss weights if possible.
-
-        Parameters
-        ----------
-
-        train_dataloader
-            The data loader to use for training
-
-        valid_dataloader
-            The data loaders to use for each of the validation sets
-
-        Raises
-        ------
-
-        RuntimeError
-            If train or validation losses are not of type
-            :py:class:`torch.nn.BCEWithLogitsLoss`.
-        """
-
-        from .loss_weights import get_label_weights
-
-        if isinstance(self._train_loss, torch.nn.BCEWithLogitsLoss):
-            logger.info("Reweighting BCEWithLogitsLoss training loss.")
-            weights = get_label_weights(train_dataloader)
-            self._train_loss = torch.nn.BCEWithLogitsLoss(pos_weight=weights)
-        else:
-            raise RuntimeError(
-                "Training loss is not BCEWithLogitsLoss - dunno how to balance"
-            )
-
-        if isinstance(self._validation_loss, torch.nn.BCEWithLogitsLoss):
-            logger.info("Reweighting BCEWithLogitsLoss validation loss.")
-            weights = get_label_weights(valid_dataloader)
-            self._validation_loss = torch.nn.BCEWithLogitsLoss(weights)
-        else:
-            raise RuntimeError(
-                "Validation loss is not BCEWithLogitsLoss - dunno how to balance"
-            )
     def training_step(self, batch, _):
         images = batch[0]
         labels = batch[1]["label"]
...
@@ -7,10 +7,12 @@ import logging
 import torch
 import torch.utils.data

+from ..data.typing import DataLoader
+
 logger = logging.getLogger(__name__)

-def get_label_weights(
+def _get_label_weights(
     dataloader: torch.utils.data.DataLoader,
 ) -> torch.Tensor:
     """Computes the weights of each class of a DataLoader.
@@ -68,3 +70,23 @@ def get_label_weights(
     )
     return positive_weights
+def make_balanced_bcewithlogitsloss(
+    dataloader: DataLoader,
+) -> torch.nn.BCEWithLogitsLoss:
+    """Returns a balanced binary-cross-entropy loss.
+
+    The loss is weighted using the ratio between positives and total examples
+    available.
+
+    Returns
+    -------
+
+    loss
+        An instance of the weighted loss
+    """
+
+    weights = _get_label_weights(dataloader)
+    return torch.nn.BCEWithLogitsLoss(pos_weight=weights)
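For reference, the weighting applied by the returned loss is documented `torch.nn.BCEWithLogitsLoss` behaviour: `pos_weight` multiplies the loss term of positive targets, so under-represented positives contribute more to the gradient. The tensors below are made-up values:

    import torch

    loss = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor([3.0]))
    logits = torch.tensor([[0.2], [-1.3]])
    targets = torch.tensor([[1.0], [0.0]])
    # the first (positive) sample's term is scaled by 3 before averaging
    value = loss(logits, targets)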
@@ -13,7 +13,7 @@ import torch.optim.optimizer
 import torch.utils.data
 import torchvision.transforms
-from ..data.typing import DataLoader, TransformSequence
+from ..data.typing import TransformSequence
 from .typing import Checkpoint
 logger = logging.getLogger(__name__)
@@ -192,7 +192,8 @@ class Pasa(pl.LightningModule):
     def on_save_checkpoint(self, checkpoint: Checkpoint) -> None:
         """Called by Lightning to restore your model.

-        If you saved something with on_save_checkpoint() this is your chance to restore this.
+        If you saved something with on_save_checkpoint() this is your chance to
+        restore this.

         Parameters
         ----------
@@ -232,50 +233,6 @@ class Pasa(pl.LightningModule):
         )
         self.normalizer = make_z_normalizer(dataloader)
-    def balance_losses_by_class(
-        self,
-        train_dataloader: DataLoader,
-        valid_dataloader: dict[str, DataLoader],
-    ):
-        """Reweights loss weights if possible.
-
-        Parameters
-        ----------
-
-        train_dataloader
-            The data loader to use for training
-
-        valid_dataloader
-            The data loaders to use for each of the validation sets
-
-        Raises
-        ------
-
-        RuntimeError
-            If train or validation losses are not of type
-            :py:class:`torch.nn.BCEWithLogitsLoss`.
-        """
-
-        from .loss_weights import get_label_weights
-
-        if isinstance(self._train_loss, torch.nn.BCEWithLogitsLoss):
-            logger.info("Reweighting BCEWithLogitsLoss training loss.")
-            weights = get_label_weights(train_dataloader)
-            self._train_loss = torch.nn.BCEWithLogitsLoss(pos_weight=weights)
-        else:
-            raise RuntimeError(
-                "Training loss is not BCEWithLogitsLoss - dunno how to balance"
-            )
-
-        if isinstance(self._validation_loss, torch.nn.BCEWithLogitsLoss):
-            logger.info("Reweighting BCEWithLogitsLoss validation loss.")
-            weights = get_label_weights(valid_dataloader)
-            self._validation_loss = torch.nn.BCEWithLogitsLoss(weights)
-        else:
-            raise RuntimeError(
-                "Validation loss is not BCEWithLogitsLoss - dunno how to balance"
-            )
     def training_step(self, batch, _):
         images = batch[0]
         labels = batch[1]["label"]
...
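All three `training_step` methods are truncated at the same point in this view. A sketch of how such a step typically completes under Lightning, assuming the forward pass yields raw logits shaped like the labels (an assumption, not shown in this commit):

    def training_step(self, batch, _):
        images = batch[0]
        labels = batch[1]["label"]
        # assumption: self(images) returns raw logits for BCEWithLogitsLoss
        logits = self(images)
        loss = self._train_loss(logits, labels.float())
        self.log("loss", loss)  # standard Lightning metric logging
        return loss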