Skip to content
Snippets Groups Projects
Commit 48ff8cce authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[data.dataset] Make function private

parent c3a58f11
No related branches found
No related tags found
1 merge request!6Making use of LightningDataModule and simplification of data loading
...@@ -10,7 +10,7 @@ import torch.utils.data ...@@ -10,7 +10,7 @@ import torch.utils.data
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def get_positive_weights(dataset): def _get_positive_weights(dataset):
"""Compute the positive weights of each class of the dataset to balance the """Compute the positive weights of each class of the dataset to balance the
BCEWithLogitsLoss criterion. BCEWithLogitsLoss criterion.
...@@ -84,7 +84,7 @@ def reweight_BCEWithLogitsLoss(datamodule, model, criterion, criterion_valid): ...@@ -84,7 +84,7 @@ def reweight_BCEWithLogitsLoss(datamodule, model, criterion, criterion_valid):
# Redefine a weighted criterion if possible # Redefine a weighted criterion if possible
if isinstance(criterion, torch.nn.BCEWithLogitsLoss): if isinstance(criterion, torch.nn.BCEWithLogitsLoss):
positive_weights = get_positive_weights(train_dataset) positive_weights = _get_positive_weights(train_dataset)
model.hparams.criterion = BCEWithLogitsLoss(pos_weight=positive_weights) model.hparams.criterion = BCEWithLogitsLoss(pos_weight=positive_weights)
else: else:
logger.warning("Weighted criterion not supported") logger.warning("Weighted criterion not supported")
...@@ -95,7 +95,7 @@ def reweight_BCEWithLogitsLoss(datamodule, model, criterion, criterion_valid): ...@@ -95,7 +95,7 @@ def reweight_BCEWithLogitsLoss(datamodule, model, criterion, criterion_valid):
isinstance(criterion_valid, torch.nn.BCEWithLogitsLoss) isinstance(criterion_valid, torch.nn.BCEWithLogitsLoss)
or criterion_valid is None or criterion_valid is None
): ):
positive_weights = get_positive_weights(validation_dataset) positive_weights = _get_positive_weights(validation_dataset)
model.hparams.criterion_valid = BCEWithLogitsLoss( model.hparams.criterion_valid = BCEWithLogitsLoss(
pos_weight=positive_weights pos_weight=positive_weights
) )
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment