diff --git a/src/mednet/models/loss_weights.py b/src/mednet/models/loss_weights.py
index f91e1eda1f563f5e1f53922576605704f4071e64..d04bdfea67ea391b7a18a066b6aa2b563a345aea 100644
--- a/src/mednet/models/loss_weights.py
+++ b/src/mednet/models/loss_weights.py
@@ -162,7 +162,7 @@ def get_positive_weights(
 
     targets_tensor = torch.tensor(targets_list)
 
-    if len(list(targets_tensor.shape)) == 1:
+    if targets_tensor.shape[0] == 1:
         logger.info("Computing positive weights assuming binary labels.")
         positive_weights = compute_binary_weights(targets_tensor)
     else: