Commit 51b5d6f7 authored by Daniel CARRON

[train] Use loss balancing instead of sampler balancing

This currently only supports a single validation dataloader
parent 5fae86b8
Merge request !38: Replace sampler balancing by loss balancing
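
The change swaps class balancing at the sampler level (re-weighting which samples get drawn) for balancing inside the loss (re-weighting how errors are scored). A minimal sketch of the two strategies, using PyTorch's BCEWithLogitsLoss and made-up label counts; none of these tensors come from the project:

# Illustrative only: contrasts the two balancing strategies this commit swaps.
import torch

targets = torch.tensor([0, 0, 0, 0, 1])  # 4 negatives, 1 positive

# Sampler balancing (removed): oversample the minority class so batches
# come out roughly class-balanced.
weights = torch.where(targets == 1, 4.0, 1.0)  # draw positives 4x as often
sampler = torch.utils.data.WeightedRandomSampler(
    weights, num_samples=len(targets), replacement=True
)

# Loss balancing (added): sample uniformly, but make errors on positive
# examples count four times as much inside the loss.
pos_weight = (targets == 0).sum() / (targets == 1).sum()  # tensor(4.)
loss = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)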
@@ -481,10 +481,6 @@ class ConcatDataModule(lightning.LightningDataModule):
        for CPU memory. Sufficient CPU memory must be available before you set
        this attribute to ``True``. It is typically useful for relatively small
        datasets.
-    balance_sampler_by_class
-        If set, then modifies the random sampler used during training and
-        validation to balance sample picking probability, making samples
-        across classes **and** datasets equitable.
    batch_size
        Number of samples in every **training** batch (this parameter affects
        memory requirements for the network). If the number of samples in the
@@ -529,7 +525,6 @@ class ConcatDataModule(lightning.LightningDataModule):
        database_name: str = "",
        split_name: str = "",
        cache_samples: bool = False,
-       balance_sampler_by_class: bool = False,
        batch_size: int = 1,
        batch_chunk_count: int = 1,
        drop_incomplete_batch: bool = False,
@@ -552,7 +547,6 @@ class ConcatDataModule(lightning.LightningDataModule):
        self.cache_samples = cache_samples
        self._train_sampler = None
-       self.balance_sampler_by_class = balance_sampler_by_class
        self._model_transforms: list[Transform] | None = None
@@ -667,40 +661,6 @@ class ConcatDataModule(lightning.LightningDataModule):
        )
        self._datasets = {}

-   @property
-   def balance_sampler_by_class(self) -> bool:
-       """Whether to balance samples across labels/datasets.
-
-       If set, then modifies the random sampler used during training
-       and validation to balance sample picking probability, making
-       samples across classes **and** datasets equitable.
-
-       .. warning::
-
-          This method does **NOT** balance the sampler per dataset, in case
-          multiple datasets compose the same training set. It only balances
-          samples according to their ground-truth (labels). If you'd like to
-          have samples balanced per dataset, then implement your own data
-          module inheriting from this one.
-
-       Returns
-       -------
-       bool
-           True if self._train_sampler is set, else False.
-       """
-       return self._train_sampler is not None
-
-   @balance_sampler_by_class.setter
-   def balance_sampler_by_class(self, value: bool):
-       if value:
-           if "train" not in self._datasets:
-               self._setup_dataset("train")
-           self._train_sampler = _make_balanced_random_sampler(
-               self._datasets["train"],
-           )
-       else:
-           self._train_sampler = None

    def set_chunk_size(self, batch_size: int, batch_chunk_count: int) -> None:
        """Coherently set the batch-chunk-size after validation.
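The setter removed above delegates to _make_balanced_random_sampler, whose body is not part of this diff. A hypothetical reconstruction of such a helper, assuming each sample is a (data, metadata) pair with an integer "label" entry (the same access pattern _get_label_weights uses below):

# Hypothetical reconstruction; the helper's actual body is not shown in
# this diff. Inverse-frequency weighting is the usual implementation.
import collections

import torch
import torch.utils.data


def _make_balanced_random_sampler(dataset):
    """Return a sampler that draws every label with equal probability."""
    labels = [metadata["label"] for _, metadata in dataset]
    counts = collections.Counter(labels)
    # Rare labels get proportionally larger weights, so they end up drawn
    # as often as frequent ones.
    weights = torch.tensor([1.0 / counts[label] for label in labels])
    return torch.utils.data.WeightedRandomSampler(
        weights, num_samples=len(weights), replacement=True
    )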
@@ -34,7 +34,6 @@ def _get_label_weights(
    torch.Tensor
        The positive weight of each class in the dataset given as input.
    """
-
    targets = torch.tensor(
        [sample for batch in dataloader for sample in batch[1]["label"]],
    )
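Only the targets construction of _get_label_weights is visible in the hunk above. A hedged sketch of how the rest of the body plausibly derives the weight for a single binary label (per its docstring, the real helper returns one positive weight per class):

# Assumed completion of _get_label_weights for the binary case; only the
# `targets` lines above appear in the diff.
import torch


def _get_label_weights(dataloader) -> torch.Tensor:
    targets = torch.tensor(
        [sample for batch in dataloader for sample in batch[1]["label"]],
    )
    # pos_weight = #negatives / #positives: values above 1 up-weight the
    # rarer positive class when passed to BCEWithLogitsLoss.
    return (targets == 0).sum() / (targets == 1).sum()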
@@ -13,6 +13,7 @@ import torch.utils.data
import torchvision.transforms

from ..data.typing import TransformSequence
+from .loss_weights import _get_label_weights
from .typing import Checkpoint

logger = logging.getLogger(__name__)
@@ -141,3 +142,37 @@ class Model(pl.LightningModule):
            self.parameters(),
            **self._optimizer_arguments,
        )

+   def balance_losses(self, datamodule) -> None:
+       """Balance the loss based on the distribution of targets in the
+       datamodule, if the loss function supports it.
+
+       Parameters
+       ----------
+       datamodule
+           Instance of a datamodule.
+       """
+       logger.info(f"Balancing training loss function {self._train_loss}.")
+       try:
+           getattr(self._train_loss, "pos_weight")
+       except AttributeError:
+           logger.warning(
+               "Training loss does not possess a 'pos_weight' attribute and will not be balanced."
+           )
+       else:
+           train_weights = _get_label_weights(datamodule.train_dataloader())
+           setattr(self._train_loss, "pos_weight", train_weights)
+
+       logger.info(
+           f"Balancing validation loss function {self._validation_loss}."
+       )
+       try:
+           getattr(self._validation_loss, "pos_weight")
+       except AttributeError:
+           logger.warning(
+               "Validation loss does not possess a 'pos_weight' attribute and will not be balanced."
+           )
+       else:
+           validation_weights = _get_label_weights(
+               datamodule.val_dataloader()["validation"]
+           )
+           setattr(self._validation_loss, "pos_weight", validation_weights)
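
What the new balance_losses hook effectively does to a compatible loss, demonstrated on a bare BCEWithLogitsLoss with an assumed weight of 4.0 (four negatives per positive); in the real flow the weights come from _get_label_weights:

# Assumed, self-contained demonstration of the pos_weight overwrite.
import torch

train_loss = torch.nn.BCEWithLogitsLoss()
print(train_loss.pos_weight)  # None: unbalanced by default

# balance_losses boils down to this assignment (the weight is illustrative):
train_loss.pos_weight = torch.tensor(4.0)

logits = torch.tensor([0.0])
target = torch.tensor([1.0])
print(train_loss(logits, target))  # errors on positives now weigh 4x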
@@ -296,10 +296,8 @@ def train(
    # of class examples available in the training set. Also affects the
    # validation loss if a validation set is available on the DataModule.
    if balance_classes:
-       logger.info("Applying DataModule train sampler balancing...")
-       datamodule.balance_sampler_by_class = True
-       # logger.info("Applying train/valid loss balancing...")
-       # model.balance_losses_by_class(datamodule)
+       logger.info("Applying train/valid loss balancing...")
+       model.balance_losses(datamodule)
    else:
        logger.info(
            "Skipping sample class/dataset ownership balancing on user request",