diff --git a/score_reg_wacv/losses/loss_fns.py b/score_reg_wacv/losses/loss_fns.py
index 3eefd56cc9bb1749b265cdbabb08d0eba45777ad..e4719a50172619bf253b40886748e02fd259c217 100644
--- a/score_reg_wacv/losses/loss_fns.py
+++ b/score_reg_wacv/losses/loss_fns.py
@@ -84,7 +84,7 @@ class IntraDemogLoss(nn.Module):
         cat_scores = torch.cat((genuine_scores, impostor_scores))
         cat_labels = torch.cat((genuine_labels, imposter_labels))
 
-        output = self.loss_fn(new_scores[:, 0], new_labels)
+        output = self.loss_fn(cat_scores[:, 0], cat_labels)
         return output