diff --git a/src/ptbench/data/dataset.py b/src/ptbench/data/dataset.py index c85fbcb3d221a2e73125d53785e0fff8cbdaa464..af8a736454608961535caee8cddb432a42199410 100644 --- a/src/ptbench/data/dataset.py +++ b/src/ptbench/data/dataset.py @@ -10,7 +10,7 @@ import torch.utils.data logger = logging.getLogger(__name__) -def get_positive_weights(dataset): +def _get_positive_weights(dataset): """Compute the positive weights of each class of the dataset to balance the BCEWithLogitsLoss criterion. @@ -84,7 +84,7 @@ def reweight_BCEWithLogitsLoss(datamodule, model, criterion, criterion_valid): # Redefine a weighted criterion if possible if isinstance(criterion, torch.nn.BCEWithLogitsLoss): - positive_weights = get_positive_weights(train_dataset) + positive_weights = _get_positive_weights(train_dataset) model.hparams.criterion = BCEWithLogitsLoss(pos_weight=positive_weights) else: logger.warning("Weighted criterion not supported") @@ -95,7 +95,7 @@ def reweight_BCEWithLogitsLoss(datamodule, model, criterion, criterion_valid): isinstance(criterion_valid, torch.nn.BCEWithLogitsLoss) or criterion_valid is None ): - positive_weights = get_positive_weights(validation_dataset) + positive_weights = _get_positive_weights(validation_dataset) model.hparams.criterion_valid = BCEWithLogitsLoss( pos_weight=positive_weights )