diff --git a/src/mednet/data/datamodule.py b/src/mednet/data/datamodule.py
index c73bb27932c52e2061134a7f007256c7f6292161..2de54fdf96d95ac4f553c43b548d168dcf2bace4 100644
--- a/src/mednet/data/datamodule.py
+++ b/src/mednet/data/datamodule.py
@@ -481,10 +481,6 @@ class ConcatDataModule(lightning.LightningDataModule):
         for CPU memory. Sufficient CPU memory must be available before you set
         this attribute to ``True``. It is typically useful for relatively small
         datasets.
-    balance_sampler_by_class
-        If set, then modifies the random sampler used during training and
-        validation to balance sample picking probability, making sample
-        across classes **and** datasets equitable.
     batch_size
         Number of samples in every **training** batch (this parameter affects
         memory requirements for the network). If the number of samples in the
@@ -529,7 +525,6 @@ class ConcatDataModule(lightning.LightningDataModule):
         database_name: str = "",
         split_name: str = "",
         cache_samples: bool = False,
-        balance_sampler_by_class: bool = False,
         batch_size: int = 1,
         batch_chunk_count: int = 1,
         drop_incomplete_batch: bool = False,
@@ -552,7 +547,6 @@ class ConcatDataModule(lightning.LightningDataModule):
 
         self.cache_samples = cache_samples
         self._train_sampler = None
-        self.balance_sampler_by_class = balance_sampler_by_class
 
         self._model_transforms: list[Transform] | None = None
 
@@ -667,40 +661,6 @@ class ConcatDataModule(lightning.LightningDataModule):
             )
             self._datasets = {}
 
-    @property
-    def balance_sampler_by_class(self) -> bool:
-        """Whether to balance samples across labels/datasets.
-
-        If set, then modifies the random sampler used during training
-        and validation to balance sample picking probability, making
-        sample across classes **and** datasets equitable.
-
-        .. warning::
-
-           This method does **NOT** balance the sampler per dataset, in case
-           multiple datasets compose the same training set. It only balances
-           samples acording to their ground-truth (labels). If you'd like to
-           have samples balanced per dataset, then implement your own data
-           module inheriting from this one.
-
-        Returns
-        -------
-        bool
-            True if self._train_sample is set, else False.
-        """
-        return self._train_sampler is not None
-
-    @balance_sampler_by_class.setter
-    def balance_sampler_by_class(self, value: bool):
-        if value:
-            if "train" not in self._datasets:
-                self._setup_dataset("train")
-            self._train_sampler = _make_balanced_random_sampler(
-                self._datasets["train"],
-            )
-        else:
-            self._train_sampler = None
-
     def set_chunk_size(self, batch_size: int, batch_chunk_count: int) -> None:
         """Coherently set the batch-chunk-size after validation.
 
diff --git a/src/mednet/models/loss_weights.py b/src/mednet/models/loss_weights.py
index bf965790cade10d68e19c0bf372c9fa7bf4d5409..32f1e33e459b0f7b08a9b4f13eb72e05ee65492d 100644
--- a/src/mednet/models/loss_weights.py
+++ b/src/mednet/models/loss_weights.py
@@ -34,7 +34,6 @@ def _get_label_weights(
     torch.Tensor
         The positive weight of each class in the dataset given as input.
     """
-
     targets = torch.tensor(
         [sample for batch in dataloader for sample in batch[1]["label"]],
     )
diff --git a/src/mednet/models/model.py b/src/mednet/models/model.py
index 50e314bb3380246f738a44b720b851ba97d9502c..a0b3701eae4376481a3875e5c53a91ed6976fe05 100644
--- a/src/mednet/models/model.py
+++ b/src/mednet/models/model.py
@@ -13,6 +13,7 @@ import torch.utils.data
 import torchvision.transforms
 
 from ..data.typing import TransformSequence
+from .loss_weights import _get_label_weights
 from .typing import Checkpoint
 
 logger = logging.getLogger(__name__)
@@ -141,3 +142,37 @@ class Model(pl.LightningModule):
             self.parameters(),
             **self._optimizer_arguments,
         )
+
+    def balance_losses(self, datamodule) -> None:
+        """Balance the loss based on the distribution of targets in the datamodule, if the loss function supports it.
+
+        Parameters
+        ----------
+        datamodule
+            Instance of a datamodule.
+        """
+        logger.info(f"Balancing training loss function {self._train_loss}.")
+        try:
+            getattr(self._train_loss, "pos_weight")
+        except AttributeError:
+            logger.warning(
+                "Training loss does not possess a 'pos_weight' attribute and will not be balanced."
+            )
+        else:
+            train_weights = _get_label_weights(datamodule.train_dataloader())
+            setattr(self._train_loss, "pos_weight", train_weights)
+
+        logger.info(
+            f"Balancing validation loss function {self._validation_loss}."
+        )
+        try:
+            getattr(self._validation_loss, "pos_weight")
+        except AttributeError:
+            logger.warning(
+                "Validation loss does not possess a 'pos_weight' attribute and will not be balanced."
+            )
+        else:
+            validation_weights = _get_label_weights(
+                datamodule.val_dataloader()["validation"]
+            )
+            setattr(self._validation_loss, "pos_weight", validation_weights)
diff --git a/src/mednet/scripts/train.py b/src/mednet/scripts/train.py
index 68a4e7e721f412fbed430774491147f9c1130577..e83ae359b0be523565f669c57bf1de41520ccac7 100644
--- a/src/mednet/scripts/train.py
+++ b/src/mednet/scripts/train.py
@@ -296,10 +296,8 @@ def train(
     # of class examples available in the training set. Also affects the
     # validation loss if a validation set is available on the DataModule.
     if balance_classes:
-        logger.info("Applying DataModule train sampler balancing...")
-        datamodule.balance_sampler_by_class = True
-        # logger.info("Applying train/valid loss balancing...")
-        # model.balance_losses_by_class(datamodule)
+        logger.info("Applying train/valid loss balancing...")
+        model.balance_losses(datamodule)
     else:
         logger.info(
             "Skipping sample class/dataset ownership balancing on user request"