Commit 1f0e3938 authored by Daniel CARRON

[model] Fix balancing of multiclass targets

parent b32ebbe5
Merge request !38: Replace sampler balancing by loss balancing
@@ -3,86 +3,180 @@
 # SPDX-License-Identifier: GPL-3.0-or-later
 
 import logging
+import typing
+from collections import Counter, defaultdict
 
 import torch
 import torch.utils.data
 
-from ..data.typing import DataLoader
-
 logger = logging.getLogger(__name__)
 
 
-def _get_label_weights(
-    dataloader: torch.utils.data.DataLoader,
-) -> torch.Tensor:
-    """Compute the weights of each class of a DataLoader.
-
-    This function inputs a pytorch DataLoader and computes the ratio between
-    the number of negative and positive samples (scalar).  The weight can be
-    used to adjust minimisation criteria in cases where there is a huge data
-    imbalance.
-
-    It returns a vector with weights (inverse counts) for each label.
-
-    Parameters
-    ----------
-    dataloader
-        A DataLoader from which to compute the positive weights.  Entries must
-        be a dictionary which must contain a ``label`` key.
-
-    Returns
-    -------
-    torch.Tensor
-        The positive weight of each class in the dataset given as input.
-    """
-    targets = torch.tensor(
-        [sample for batch in dataloader for sample in batch[1]["label"]],
-    )
-
-    # Binary labels
-    if len(list(targets.shape)) == 1:
-        class_sample_count = [
-            float((targets == t).sum().item())
-            for t in torch.unique(targets, sorted=True)
-        ]
-        # Divide negatives by positives
-        positive_weights = torch.tensor(
-            [class_sample_count[0] / class_sample_count[1]],
-        ).reshape(-1)
-
-    # Multiclass labels
-    else:
-        class_sample_count = torch.sum(targets, dim=0)
-        negative_class_sample_count = (
-            torch.full((targets.size()[1],), float(targets.size()[0]))
-            - class_sample_count
-        )
-
-        positive_weights = negative_class_sample_count / (
-            class_sample_count + negative_class_sample_count
-        )
-
-    return positive_weights
-
-
-def make_balanced_bcewithlogitsloss(
-    dataloader: DataLoader,
-) -> torch.nn.BCEWithLogitsLoss:
-    """Return a balanced binary-cross-entropy loss.
-
-    The loss is weighted using the ratio between positives and total examples
-    available.
-
-    Parameters
-    ----------
-    dataloader
-        The DataLoader to use to compute the BCE weights.
-
-    Returns
-    -------
-    torch.nn.BCEWithLogitsLoss
-        An instance of the weighted loss.
-    """
-    weights = _get_label_weights(dataloader)
-    return torch.nn.BCEWithLogitsLoss(pos_weight=weights)
+def compute_binary_weights(targets):
+    """Compute the positive weights when using binary targets.
+
+    Parameters
+    ----------
+    targets
+        A tensor of integer values of length n.
+
+    Returns
+    -------
+        The positive weights per class.
+    """
+    class_sample_count = [
+        float((targets == t).sum().item())
+        for t in torch.unique(targets, sorted=True)
+    ]
+
+    # Divide negatives by positives
+    return torch.tensor(
+        [class_sample_count[0] / class_sample_count[1]],
+    ).reshape(-1)
+
+
+def compute_multiclass_weights(targets):
+    """Compute the positive weights when using exclusive, multiclass targets.
+
+    Parameters
+    ----------
+    targets
+        A [C x n] tensor of integer values, where `C` is the number of target
+        classes and `n` the number of samples.
+
+    Returns
+    -------
+        The positive weights per class.
+    """
+    class_sample_count = torch.sum(targets, dim=1)
+    negative_class_sample_count = (
+        torch.full((targets.size()[0],), float(targets.size()[1]))
+        - class_sample_count
+    )
+
+    # The denominator equals the total number of samples, so each weight is
+    # the fraction of negative samples for that class.
+    return negative_class_sample_count / (
+        class_sample_count + negative_class_sample_count
+    )
+
+
+def compute_non_exclusive_multiclass_weights(targets):
+    """Compute the positive weights when using non-exclusive, multiclass targets.
+
+    Not yet implemented: calling this function raises ``ValueError``.
+
+    Parameters
+    ----------
+    targets
+        A [C x n] tensor of integer values, where `C` is the number of target
+        classes and `n` the number of samples.
+
+    Returns
+    -------
+        The positive weights per class.
+    """
+    raise ValueError(
+        "Computing weights of multi-class, non-exclusive labels is not yet supported."
+    )
+
+
+def is_multiclass_exclusive(targets: torch.Tensor) -> bool:
+    """Check whether each sample in a [C x n] tensor of integer targets belongs
+    to a single class.
+
+    Parameters
+    ----------
+    targets
+        A [C x n] tensor of integer values, where `C` is the number of target
+        classes and `n` the number of samples.
+
+    Returns
+    -------
+        True if every sample belongs to a single class, False otherwise (a
+        sample can belong to multiple classes).
+    """
+    max_counts = []
+
+    transposed_targets = torch.transpose(targets, 0, 1)
+    for t in transposed_targets:
+        # Drop entries with value 2 before counting label occurrences.
+        filtered_list = [i for i in t.tolist() if i != 2]
+        counts = Counter(filtered_list)
+        max_counts.append(max(counts.values()))
+
+    if set(max_counts) == {1}:
+        return True
+
+    return False
+
+
+def tensor_to_list(tensor) -> list[typing.Any]:
+    """Convert a torch.Tensor to a list.
+
+    This is necessary, as torch.Tensor.tolist returns an int when the tensor
+    contains a single value.
+
+    Parameters
+    ----------
+    tensor
+        The tensor to convert to a list.
+
+    Returns
+    -------
+        The tensor converted to a list.
+    """
+    tensor = tensor.tolist()
+    if isinstance(tensor, int):
+        return [tensor]
+    return tensor
+
+
+def get_positive_weights(
+    dataloader: torch.utils.data.DataLoader,
+) -> torch.Tensor:
+    """Compute the weights of each class of a DataLoader.
+
+    This function inputs a pytorch DataLoader and computes the ratio between
+    the number of negative and positive samples (scalar).  The weight can be
+    used to adjust minimisation criteria in cases where there is a huge data
+    imbalance.
+
+    It returns a vector with weights (inverse counts) for each label.
+
+    Parameters
+    ----------
+    dataloader
+        A DataLoader from which to compute the positive weights.  Entries must
+        be a dictionary which must contain a ``label`` key.
+
+    Returns
+    -------
+        The positive weight of each class in the dataset given as input.
+    """
+    targets = defaultdict(list)
+
+    for batch in dataloader:
+        for class_idx, class_targets in enumerate(batch[1]["label"]):
+            # Targets are either a single tensor (binary case) or a list of
+            # tensors (multilabel case).
+            if isinstance(batch[1]["label"], list):
+                targets[class_idx].extend(tensor_to_list(class_targets))
+            else:
+                targets[0].extend(tensor_to_list(class_targets))
+
+    targets_list = []
+    for k in sorted(targets.keys()):
+        targets_list.append(targets[k])
+
+    targets_tensor = torch.tensor(targets_list)
+
+    if targets_tensor.ndim == 1:
+        logger.info("Computing positive weights assuming binary labels.")
+        positive_weights = compute_binary_weights(targets_tensor)
+    else:
+        if is_multiclass_exclusive(targets_tensor):
+            logger.info(
+                "Computing positive weights assuming multiclass, exclusive labels."
+            )
+            positive_weights = compute_multiclass_weights(targets_tensor)
+        else:
+            logger.info(
+                "Computing positive weights assuming multiclass, non-exclusive labels."
+            )
+            positive_weights = compute_non_exclusive_multiclass_weights(
+                targets_tensor
+            )
+
+    return positive_weights
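
For illustration, here is a minimal sketch (not part of this commit) of what the two implemented weight computations produce on hand-made targets; the tensor values are invented for the example:

import torch

# Binary case, mirroring compute_binary_weights: a flat 0/1 target vector
# with three negatives and one positive.
targets = torch.tensor([0, 0, 0, 1])
neg, pos = (float((targets == t).sum()) for t in torch.unique(targets, sorted=True))
print(torch.tensor([neg / pos]))  # tensor([3.]) -- positives up-weighted 3x

# Exclusive multiclass case, mirroring compute_multiclass_weights: a [C x n]
# one-hot matrix (C = 3 classes, n = 4 samples).
targets = torch.tensor([[1, 0, 0, 0], [0, 1, 1, 0], [0, 0, 0, 1]])
positives = torch.sum(targets, dim=1)           # tensor([1, 2, 1])
negatives = float(targets.size(1)) - positives  # tensor([3., 2., 3.])
print(negatives / (positives + negatives))      # tensor([0.7500, 0.5000, 0.7500])

Each multiclass weight is thus the fraction of negative samples for that class. The hunks below update the model module to use the renamed helper.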
@@ -13,7 +13,7 @@ import torch.utils.data
 import torchvision.transforms
 
 from ..data.typing import TransformSequence
-from .loss_weights import _get_label_weights
+from .loss_weights import get_positive_weights
 from .typing import Checkpoint
 
 logger = logging.getLogger(__name__)
@@ -152,6 +152,7 @@ class Model(pl.LightningModule):
         datamodule
             Instance of a datamodule.
         """
+
         logger.info(f"Balancing training loss function {self._train_loss}.")
         try:
             getattr(self._train_loss, "pos_weight")
@@ -160,7 +161,7 @@ class Model(pl.LightningModule):
                 "Training loss does not possess a 'pos_weight' attribute and will not be balanced."
             )
         else:
-            train_weights = _get_label_weights(datamodule.train_dataloader())
+            train_weights = get_positive_weights(datamodule.train_dataloader())
             setattr(self._train_loss, "pos_weight", train_weights)
             logger.info(
@@ -185,7 +186,7 @@ class Model(pl.LightningModule):
             )
             for val_dataset_key in datamodule_validation_keys:
-                validation_weights = _get_label_weights(
+                validation_weights = get_positive_weights(
                     datamodule.val_dataloader()[val_dataset_key]
                 )
                 new_validation_losses.append(
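
With the weights attached, the effect at training time is the standard `pos_weight` re-scaling in `torch.nn.BCEWithLogitsLoss`. A brief sketch (not part of the commit; values invented):

import torch

pos_weight = torch.tensor([3.0])  # e.g. the binary weights computed above

# Equivalent in effect to setattr(self._train_loss, "pos_weight", pos_weight)
# on an existing BCEWithLogitsLoss instance:
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)

logits = torch.tensor([0.2, -1.3])  # raw model outputs
labels = torch.tensor([1.0, 0.0])   # ground truth
loss = criterion(logits, labels)    # the positive-sample term is scaled by 3.0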