diff --git a/score_reg_wacv/losses/loss_fns.py b/score_reg_wacv/losses/loss_fns.py index 3eefd56cc9bb1749b265cdbabb08d0eba45777ad..e4719a50172619bf253b40886748e02fd259c217 100644 --- a/score_reg_wacv/losses/loss_fns.py +++ b/score_reg_wacv/losses/loss_fns.py @@ -84,7 +84,7 @@ class IntraDemogLoss(nn.Module): cat_scores = torch.cat((genuine_scores, impostor_scores)) cat_labels = torch.cat((genuine_labels, imposter_labels)) - output = self.loss_fn(new_scores[:, 0], new_labels) + output = self.loss_fn(cat_scores[:, 0], cat_labels) return output