diff --git a/src/mednet/models/loss_weights.py b/src/mednet/models/loss_weights.py index 32f1e33e459b0f7b08a9b4f13eb72e05ee65492d..f91e1eda1f563f5e1f53922576605704f4071e64 100644 --- a/src/mednet/models/loss_weights.py +++ b/src/mednet/models/loss_weights.py @@ -3,86 +3,180 @@ # SPDX-License-Identifier: GPL-3.0-or-later import logging +import typing +from collections import Counter import torch import torch.utils.data -from ..data.typing import DataLoader - logger = logging.getLogger(__name__) -def _get_label_weights( - dataloader: torch.utils.data.DataLoader, -) -> torch.Tensor: - """Compute the weights of each class of a DataLoader. +def compute_binary_weights(targets): + """Compute the positive weights when using binary targets. - This function inputs a pytorch DataLoader and computes the ratio between - number of negative and positive samples (scalar). The weight can be used - to adjust minimisation criteria to in cases there is a huge data imbalance. + Parameters + ---------- + targets + A tensor of integer values of length n. - It returns a vector with weights (inverse counts) for each label. + Returns + ------- + The positive weights per class. + """ + class_sample_count = [ + float((targets == t).sum().item()) + for t in torch.unique(targets, sorted=True) + ] + + # Divide negatives by positives + return torch.tensor( + [class_sample_count[0] / class_sample_count[1]], + ).reshape(-1) + + +def compute_multiclass_weights(targets): + """Compute the positive weights when using exclusive, multiclass targets. Parameters ---------- - dataloader - A DataLoader from which to compute the positive weights. Entries must - be a dictionary which must contain a ``label`` key. + targets + A [C x n] tensor of integer values, where `C` is the number of target classes and `n` the number of samples. Returns ------- - torch.Tensor - The positive weight of each class in the dataset given as input. + The positive weights per class. """ - targets = torch.tensor( - [sample for batch in dataloader for sample in batch[1]["label"]], + + class_sample_count = torch.sum(targets, dim=1) + negative_class_sample_count = ( + torch.full((targets.size()[0],), float(targets.size()[1])) + - class_sample_count ) - # Binary labels - if len(list(targets.shape)) == 1: - class_sample_count = [ - float((targets == t).sum().item()) - for t in torch.unique(targets, sorted=True) - ] + return negative_class_sample_count / ( + class_sample_count + negative_class_sample_count + ) - # Divide negatives by positives - positive_weights = torch.tensor( - [class_sample_count[0] / class_sample_count[1]], - ).reshape(-1) - # Multiclass labels - else: - class_sample_count = torch.sum(targets, dim=0) - negative_class_sample_count = ( - torch.full((targets.size()[1],), float(targets.size()[0])) - - class_sample_count - ) +def compute_non_exclusive_multiclass_weights(targets): + """Compute the positive weights when using non-exclusive, multiclass targets. - positive_weights = negative_class_sample_count / ( - class_sample_count + negative_class_sample_count - ) + Parameters + ---------- + targets + A [C x n] tensor of integer values, where `C` is the number of target classes and `n` the number of samples. - return positive_weights + Returns + ------- + The positive weights per class. + """ + raise ValueError( + "Computing weights of multi-class, non-exclusive labels is not yet supported." + ) -def make_balanced_bcewithlogitsloss( - dataloader: DataLoader, -) -> torch.nn.BCEWithLogitsLoss: - """Return a balanced binary-cross-entropy loss. +def is_multicalss_exclusive(targets: torch.Tensor) -> bool: + """Given a [C x n] tensor of integer targets, checks whether samples can only belong to a single class. - The loss is weighted using the ratio between positives and total examples - available. + Parameters + ---------- + targets + A [C x n] tensor of integer values, where `C` is the number of target classes and `n` the number of samples. + + Returns + ------- + True if all samples belong to a single class, False otherwise (a sample can belong to multiple classes). + """ + max_counts = [] + transposed_targets = torch.transpose(targets, 0, 1) + for t in transposed_targets: + filtered_list = [i for i in t.tolist() if i != 2] + counts = Counter(filtered_list) + max_counts.append(max(counts.values())) + + if set(max_counts) == {1}: + return True + + return False + + +def tensor_to_list(tensor) -> list[typing.Any]: + """Convert a torch.Tensor to a list. + + This is necessary, as torch.tolist returns an int when then tensor contains a single value. + + Parameters + ---------- + tensor + The tensor to convert to a list. + + Returns + ------- + The tensor converted to a list. + """ + + tensor = tensor.tolist() + if isinstance(tensor, int): + return [tensor] + return tensor + + +def get_positive_weights( + dataloader: torch.utils.data.DataLoader, +) -> torch.Tensor: + """Compute the weights of each class of a DataLoader. + + This function inputs a pytorch DataLoader and computes the ratio between + number of negative and positive samples (scalar). The weight can be used + to adjust minimisation criteria to in cases there is a huge data imbalance. + + It returns a vector with weights (inverse counts) for each label. Parameters ---------- dataloader - The DataLoader to use to compute the BCE weights. + A DataLoader from which to compute the positive weights. Entries must + be a dictionary which must contain a ``label`` key. Returns ------- - torch.nn.BCEWithLogitsLoss - An instance of the weighted loss. + The positive weight of each class in the dataset given as input. """ - weights = _get_label_weights(dataloader) - return torch.nn.BCEWithLogitsLoss(pos_weight=weights) + from collections import defaultdict + + targets = defaultdict(list) + + for batch in dataloader: + for class_idx, class_targets in enumerate(batch[1]["label"]): + # Targets are either a single tensor (binary case) or a list of tensors (multilabel) + if isinstance(batch[1]["label"], list): + targets[class_idx].extend(tensor_to_list(class_targets)) + else: + targets[0].extend(tensor_to_list(class_targets)) + + targets_list = [] + for k in sorted(list(targets.keys())): + targets_list.append(targets[k]) + + targets_tensor = torch.tensor(targets_list) + + if len(list(targets_tensor.shape)) == 1: + logger.info("Computing positive weights assuming binary labels.") + positive_weights = compute_binary_weights(targets_tensor) + else: + if is_multicalss_exclusive(targets_tensor): + logger.info( + "Computing positive weights assuming multiclass, exclusive labels." + ) + positive_weights = compute_multiclass_weights(targets_tensor) + else: + logger.info( + "Computing positive weights assuming multiclass, non-exclusive labels." + ) + positive_weights = compute_non_exclusive_multiclass_weights( + targets_tensor + ) + + return positive_weights diff --git a/src/mednet/models/model.py b/src/mednet/models/model.py index 155ff19c0c0887b9a0504b414eb93242b7b0ff63..e5a13dc8336aa28335bba58d26e98d0814ac9ed8 100644 --- a/src/mednet/models/model.py +++ b/src/mednet/models/model.py @@ -13,7 +13,7 @@ import torch.utils.data import torchvision.transforms from ..data.typing import TransformSequence -from .loss_weights import _get_label_weights +from .loss_weights import get_positive_weights from .typing import Checkpoint logger = logging.getLogger(__name__) @@ -152,6 +152,7 @@ class Model(pl.LightningModule): datamodule Instance of a datamodule. """ + logger.info(f"Balancing training loss function {self._train_loss}.") try: getattr(self._train_loss, "pos_weight") @@ -160,7 +161,7 @@ class Model(pl.LightningModule): "Training loss does not posess a 'pos_weight' attribute and will not be balanced." ) else: - train_weights = _get_label_weights(datamodule.train_dataloader()) + train_weights = get_positive_weights(datamodule.train_dataloader()) setattr(self._train_loss, "pos_weight", train_weights) logger.info( @@ -185,7 +186,7 @@ class Model(pl.LightningModule): ) for val_dataset_key in datamodule_validation_keys: - validation_weights = _get_label_weights( + validation_weights = get_positive_weights( datamodule.val_dataloader()[val_dataset_key] ) new_validation_losses.append(