diff --git a/src/mednet/models/loss_weights.py b/src/mednet/models/loss_weights.py index f91e1eda1f563f5e1f53922576605704f4071e64..d04bdfea67ea391b7a18a066b6aa2b563a345aea 100644 --- a/src/mednet/models/loss_weights.py +++ b/src/mednet/models/loss_weights.py @@ -162,7 +162,7 @@ def get_positive_weights( targets_tensor = torch.tensor(targets_list) - if len(list(targets_tensor.shape)) == 1: + if targets_tensor.shape[0] == 1: logger.info("Computing positive weights assuming binary labels.") positive_weights = compute_binary_weights(targets_tensor) else: