Skip to content
Snippets Groups Projects
Commit 3da537d0 authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[data.dataset] Make function private

parent 4c65b4fb
No related branches found
No related tags found
1 merge request!7Reviewed DataModule design+docs+types
...@@ -10,7 +10,7 @@ import torch.utils.data ...@@ -10,7 +10,7 @@ import torch.utils.data
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def get_positive_weights(dataset): def _get_positive_weights(dataset):
"""Compute the positive weights of each class of the dataset to balance the """Compute the positive weights of each class of the dataset to balance the
BCEWithLogitsLoss criterion. BCEWithLogitsLoss criterion.
...@@ -84,7 +84,7 @@ def reweight_BCEWithLogitsLoss(datamodule, model, criterion, criterion_valid): ...@@ -84,7 +84,7 @@ def reweight_BCEWithLogitsLoss(datamodule, model, criterion, criterion_valid):
# Redefine a weighted criterion if possible # Redefine a weighted criterion if possible
if isinstance(criterion, torch.nn.BCEWithLogitsLoss): if isinstance(criterion, torch.nn.BCEWithLogitsLoss):
positive_weights = get_positive_weights(train_dataset) positive_weights = _get_positive_weights(train_dataset)
model.hparams.criterion = BCEWithLogitsLoss(pos_weight=positive_weights) model.hparams.criterion = BCEWithLogitsLoss(pos_weight=positive_weights)
else: else:
logger.warning("Weighted criterion not supported") logger.warning("Weighted criterion not supported")
...@@ -95,7 +95,7 @@ def reweight_BCEWithLogitsLoss(datamodule, model, criterion, criterion_valid): ...@@ -95,7 +95,7 @@ def reweight_BCEWithLogitsLoss(datamodule, model, criterion, criterion_valid):
isinstance(criterion_valid, torch.nn.BCEWithLogitsLoss) isinstance(criterion_valid, torch.nn.BCEWithLogitsLoss)
or criterion_valid is None or criterion_valid is None
): ):
positive_weights = get_positive_weights(validation_dataset) positive_weights = _get_positive_weights(validation_dataset)
model.hparams.criterion_valid = BCEWithLogitsLoss( model.hparams.criterion_valid = BCEWithLogitsLoss(
pos_weight=positive_weights pos_weight=positive_weights
) )
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment