Skip to content
Snippets Groups Projects
Commit 1f0e3938 authored by Daniel CARRON's avatar Daniel CARRON :b:
Browse files

[model] Fix balancing of multiclass targets

parent b32ebbe5
No related branches found
No related tags found
1 merge request!38Replace sampler balancing by loss balancing
......@@ -3,86 +3,180 @@
# SPDX-License-Identifier: GPL-3.0-or-later
import logging
import typing
from collections import Counter
import torch
import torch.utils.data
from ..data.typing import DataLoader
logger = logging.getLogger(__name__)
def _get_label_weights(
dataloader: torch.utils.data.DataLoader,
) -> torch.Tensor:
"""Compute the weights of each class of a DataLoader.
def compute_binary_weights(targets):
    """Compute the positive weights when using binary targets.

    The weight is the number of samples carrying the smallest target value
    divided by the number of samples carrying the second smallest one. With
    ``0``/``1`` targets this is the ratio of negatives to positives.

    Parameters
    ----------
    targets
        A tensor of integer values of length n.

    Returns
    -------
        The positive weights per class, as a tensor of shape ``(1,)``.

    Raises
    ------
    ValueError
        If ``targets`` holds fewer than two distinct values, in which case
        no ratio can be computed.
    """
    # ``sorted=True`` guarantees counts are ordered by ascending target value.
    _, counts = torch.unique(targets, sorted=True, return_counts=True)
    class_sample_count = counts.tolist()

    if len(class_sample_count) < 2:
        raise ValueError(
            "Cannot compute binary positive weights: expected at least two "
            f"distinct target values, found {len(class_sample_count)}."
        )

    # Divide negatives by positives
    return torch.tensor([class_sample_count[0] / class_sample_count[1]])
def compute_multiclass_weights(targets):
    """Compute the positive weights when using exclusive, multiclass targets.

    For each class, the weight is the number of samples that do not belong
    to it divided by the total number of samples.

    Parameters
    ----------
    targets
        A [C x n] tensor of integer values, where `C` is the number of target
        classes and `n` the number of samples.

    Returns
    -------
        The positive weights per class, as a tensor of length `C`.
    """
    # NOTE(review): this is negatives / total, not negatives / positives as
    # in the binary case -- confirm this is the intended balancing scheme.
    n_samples = targets.size(1)
    class_sample_count = torch.sum(targets, dim=1)

    # Plain arithmetic keeps the result on the same device as ``targets``.
    negative_class_sample_count = n_samples - class_sample_count

    return negative_class_sample_count / n_samples
# Divide negatives by positives
positive_weights = torch.tensor(
[class_sample_count[0] / class_sample_count[1]],
).reshape(-1)
# Multiclass labels
else:
class_sample_count = torch.sum(targets, dim=0)
negative_class_sample_count = (
torch.full((targets.size()[1],), float(targets.size()[0]))
- class_sample_count
)
def compute_non_exclusive_multiclass_weights(targets):
    """Compute the positive weights when using non-exclusive, multiclass targets.

    This computation is not implemented: calling the function always raises.

    Parameters
    ----------
    targets
        A [C x n] tensor of integer values, where `C` is the number of target
        classes and `n` the number of samples. Currently unused.

    Raises
    ------
    ValueError
        Always, since non-exclusive multiclass targets are not supported yet.
    """
    raise ValueError(
        "Computing weights of multi-class, non-exclusive labels is not yet supported."
    )
def make_balanced_bcewithlogitsloss(
dataloader: DataLoader,
) -> torch.nn.BCEWithLogitsLoss:
"""Return a balanced binary-cross-entropy loss.
def is_multicalss_exclusive(targets: torch.Tensor) -> bool:
    """Given a [C x n] tensor of integer targets, checks whether samples can only belong to a single class.

    A sample (one column of ``targets``) belongs to a class when its entry
    for that class equals ``1``. Any other value, including ``2``, is not
    counted as a membership.

    Parameters
    ----------
    targets
        A [C x n] tensor of integer values, where `C` is the number of target
        classes and `n` the number of samples.

    Returns
    -------
        True if no sample belongs to more than one class, False otherwise
        (a sample can belong to multiple classes).
    """
    # NOTE(review): the previous implementation discarded entries equal to 2
    # and then took the highest count of ANY remaining value, so the zeros of
    # a one-hot column with three or more classes made it return False.
    # Presumably 2 marks an unknown/ignored label -- confirm with the data.
    positives_per_sample = (targets == 1).sum(dim=0)
    return bool((positives_per_sample <= 1).all())
def tensor_to_list(tensor) -> list[typing.Any]:
    """Convert a torch.Tensor to a list.

    This is necessary, as ``torch.Tensor.tolist`` returns a bare Python
    scalar instead of a list when the tensor contains a single,
    zero-dimensional value.

    Parameters
    ----------
    tensor
        The tensor to convert to a list.

    Returns
    -------
        The tensor converted to a list. A zero-dimensional tensor yields a
        one-element list.
    """
    values = tensor.tolist()

    # Wrap any scalar (int, float or bool), not only ints, so callers can
    # always ``extend`` with the result.
    if isinstance(values, list):
        return values
    return [values]
def get_positive_weights(
    dataloader: torch.utils.data.DataLoader,
) -> torch.Tensor:
    """Compute the weights of each class of a DataLoader.

    This function iterates over a pytorch DataLoader, collects every target
    it yields and derives positive weights from them. The weights can be used
    to adjust minimisation criteria in cases where there is a huge data
    imbalance.

    Three situations are handled:

    * binary targets (``label`` is not a list): delegated to
      :py:func:`compute_binary_weights`;
    * exclusive multiclass targets (``label`` is a list with one entry per
      class): delegated to :py:func:`compute_multiclass_weights`;
    * non-exclusive multiclass targets: delegated to
      :py:func:`compute_non_exclusive_multiclass_weights`.

    Parameters
    ----------
    dataloader
        A DataLoader from which to compute the positive weights. The second
        element of each batch must be a dictionary containing a ``label``
        key.

    Returns
    -------
        The positive weight of each class in the dataset given as input.
    """
    from collections import defaultdict

    targets = defaultdict(list)
    multilabel = False

    for batch in dataloader:
        labels = batch[1]["label"]
        # Targets are either a single tensor (binary case) or a list of tensors (multilabel)
        if isinstance(labels, list):
            multilabel = True
            for class_idx, class_targets in enumerate(labels):
                targets[class_idx].extend(tensor_to_list(class_targets))
        else:
            for sample_targets in labels:
                targets[0].extend(tensor_to_list(sample_targets))

    if not multilabel:
        # Hand over a flat, 1-D tensor. Nesting the single list of targets
        # would produce a [1 x n] tensor, which made the binary case
        # unreachable and silently treated it as multiclass.
        logger.info("Computing positive weights assuming binary labels.")
        return compute_binary_weights(torch.tensor(targets[0]))

    # One row per class, ordered by class index: a [C x n] tensor.
    targets_tensor = torch.tensor([targets[k] for k in sorted(targets)])

    if is_multicalss_exclusive(targets_tensor):
        logger.info(
            "Computing positive weights assuming multiclass, exclusive labels."
        )
        return compute_multiclass_weights(targets_tensor)

    logger.info(
        "Computing positive weights assuming multiclass, non-exclusive labels."
    )
    return compute_non_exclusive_multiclass_weights(targets_tensor)
......@@ -13,7 +13,7 @@ import torch.utils.data
import torchvision.transforms
from ..data.typing import TransformSequence
from .loss_weights import _get_label_weights
from .loss_weights import get_positive_weights
from .typing import Checkpoint
logger = logging.getLogger(__name__)
......@@ -152,6 +152,7 @@ class Model(pl.LightningModule):
datamodule
Instance of a datamodule.
"""
logger.info(f"Balancing training loss function {self._train_loss}.")
try:
getattr(self._train_loss, "pos_weight")
......@@ -160,7 +161,7 @@ class Model(pl.LightningModule):
"Training loss does not posess a 'pos_weight' attribute and will not be balanced."
)
else:
train_weights = _get_label_weights(datamodule.train_dataloader())
train_weights = get_positive_weights(datamodule.train_dataloader())
setattr(self._train_loss, "pos_weight", train_weights)
logger.info(
......@@ -185,7 +186,7 @@ class Model(pl.LightningModule):
)
for val_dataset_key in datamodule_validation_keys:
validation_weights = _get_label_weights(
validation_weights = get_positive_weights(
datamodule.val_dataloader()[val_dataset_key]
)
new_validation_losses.append(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment