From 3da537d096f9bcc3220a9d3ef48f9db7a2cf521a Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Sat, 1 Jul 2023 19:07:37 +0200
Subject: [PATCH] [data.dataset] Make function private

---
 src/ptbench/data/dataset.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/ptbench/data/dataset.py b/src/ptbench/data/dataset.py
index c85fbcb3..af8a7364 100644
--- a/src/ptbench/data/dataset.py
+++ b/src/ptbench/data/dataset.py
@@ -10,7 +10,7 @@ import torch.utils.data
 logger = logging.getLogger(__name__)
 
 
-def get_positive_weights(dataset):
+def _get_positive_weights(dataset):
     """Compute the positive weights of each class of the dataset to balance the
     BCEWithLogitsLoss criterion.
 
@@ -84,7 +84,7 @@ def reweight_BCEWithLogitsLoss(datamodule, model, criterion, criterion_valid):
 
     # Redefine a weighted criterion if possible
     if isinstance(criterion, torch.nn.BCEWithLogitsLoss):
-        positive_weights = get_positive_weights(train_dataset)
+        positive_weights = _get_positive_weights(train_dataset)
         model.hparams.criterion = BCEWithLogitsLoss(pos_weight=positive_weights)
     else:
         logger.warning("Weighted criterion not supported")
@@ -95,7 +95,7 @@ def reweight_BCEWithLogitsLoss(datamodule, model, criterion, criterion_valid):
             isinstance(criterion_valid, torch.nn.BCEWithLogitsLoss)
             or criterion_valid is None
         ):
-            positive_weights = get_positive_weights(validation_dataset)
+            positive_weights = _get_positive_weights(validation_dataset)
             model.hparams.criterion_valid = BCEWithLogitsLoss(
                 pos_weight=positive_weights
             )
-- 
GitLab