From 28a261b09a8dcfe27695969becf76ff5b3473499 Mon Sep 17 00:00:00 2001 From: dcarron <daniel.carron@idiap.ch> Date: Wed, 1 May 2024 17:38:14 +0200 Subject: [PATCH] [model] Fix detection of binary targets in loss balancing --- src/mednet/models/loss_weights.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mednet/models/loss_weights.py b/src/mednet/models/loss_weights.py index f91e1eda..d04bdfea 100644 --- a/src/mednet/models/loss_weights.py +++ b/src/mednet/models/loss_weights.py @@ -162,7 +162,7 @@ def get_positive_weights( targets_tensor = torch.tensor(targets_list) - if len(list(targets_tensor.shape)) == 1: + if targets_tensor.shape[0] == 1: logger.info("Computing positive weights assuming binary labels.") positive_weights = compute_binary_weights(targets_tensor) else: -- GitLab