Skip to content
Snippets Groups Projects

Reviewed DataModule design+docs+types

Merged André Anjos requested to merge add-datamodule-andre into add-datamodule
All threads resolved!
1 file
+ 3
3
Compare changes
  • Side-by-side
  • Inline
@@ -10,7 +10,7 @@ import torch.utils.data
logger = logging.getLogger(__name__)
def get_positive_weights(dataset):
def _get_positive_weights(dataset):
"""Compute the positive weights of each class of the dataset to balance the
BCEWithLogitsLoss criterion.
@@ -84,7 +84,7 @@ def reweight_BCEWithLogitsLoss(datamodule, model, criterion, criterion_valid):
# Redefine a weighted criterion if possible
if isinstance(criterion, torch.nn.BCEWithLogitsLoss):
positive_weights = get_positive_weights(train_dataset)
positive_weights = _get_positive_weights(train_dataset)
model.hparams.criterion = BCEWithLogitsLoss(pos_weight=positive_weights)
else:
logger.warning("Weighted criterion not supported")
@@ -95,7 +95,7 @@ def reweight_BCEWithLogitsLoss(datamodule, model, criterion, criterion_valid):
isinstance(criterion_valid, torch.nn.BCEWithLogitsLoss)
or criterion_valid is None
):
positive_weights = get_positive_weights(validation_dataset)
positive_weights = _get_positive_weights(validation_dataset)
model.hparams.criterion_valid = BCEWithLogitsLoss(
pos_weight=positive_weights
)
Loading