From 28a261b09a8dcfe27695969becf76ff5b3473499 Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Wed, 1 May 2024 17:38:14 +0200
Subject: [PATCH] [model] Fix detection of binary targets in loss balancing

---
 src/mednet/models/loss_weights.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/mednet/models/loss_weights.py b/src/mednet/models/loss_weights.py
index f91e1eda..d04bdfea 100644
--- a/src/mednet/models/loss_weights.py
+++ b/src/mednet/models/loss_weights.py
@@ -162,7 +162,7 @@ def get_positive_weights(
 
     targets_tensor = torch.tensor(targets_list)
 
-    if len(list(targets_tensor.shape)) == 1:
+    if targets_tensor.shape[0] == 1:
         logger.info("Computing positive weights assuming binary labels.")
         positive_weights = compute_binary_weights(targets_tensor)
     else:
-- 
GitLab