From 3da537d096f9bcc3220a9d3ef48f9db7a2cf521a Mon Sep 17 00:00:00 2001 From: Andre Anjos <andre.dos.anjos@gmail.com> Date: Sat, 1 Jul 2023 19:07:37 +0200 Subject: [PATCH] [data.dataset] Make function private --- src/ptbench/data/dataset.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/ptbench/data/dataset.py b/src/ptbench/data/dataset.py index c85fbcb3..af8a7364 100644 --- a/src/ptbench/data/dataset.py +++ b/src/ptbench/data/dataset.py @@ -10,7 +10,7 @@ import torch.utils.data logger = logging.getLogger(__name__) -def get_positive_weights(dataset): +def _get_positive_weights(dataset): """Compute the positive weights of each class of the dataset to balance the BCEWithLogitsLoss criterion. @@ -84,7 +84,7 @@ def reweight_BCEWithLogitsLoss(datamodule, model, criterion, criterion_valid): # Redefine a weighted criterion if possible if isinstance(criterion, torch.nn.BCEWithLogitsLoss): - positive_weights = get_positive_weights(train_dataset) + positive_weights = _get_positive_weights(train_dataset) model.hparams.criterion = BCEWithLogitsLoss(pos_weight=positive_weights) else: logger.warning("Weighted criterion not supported") @@ -95,7 +95,7 @@ def reweight_BCEWithLogitsLoss(datamodule, model, criterion, criterion_valid): isinstance(criterion_valid, torch.nn.BCEWithLogitsLoss) or criterion_valid is None ): - positive_weights = get_positive_weights(validation_dataset) + positive_weights = _get_positive_weights(validation_dataset) model.hparams.criterion_valid = BCEWithLogitsLoss( pos_weight=positive_weights ) -- GitLab