diff --git a/bob/ip/binseg/engine/ssltrainer.py b/bob/ip/binseg/engine/ssltrainer.py
index 6d378268a2c5aed5e54b34ea88e86c69afe4763b..4a8cca1661cc600886f2e5fa655ae5a26245b920 100644
--- a/bob/ip/binseg/engine/ssltrainer.py
+++ b/bob/ip/binseg/engine/ssltrainer.py
@@ -376,6 +376,7 @@ def run(
                 scheduler.step()
 
             # calculates the validation loss if necessary
+            # note: validation does not comprise "unlabelled" losses
             valid_losses = None
             if valid_loader is not None:
 
@@ -385,8 +386,7 @@ def run(
                     for samples in tqdm(
                         valid_loader, desc="valid", leave=False, disable=None
                     ):
-
-                        # labelled
+                        # data forwarding on the existing network
                         images = samples[1].to(
                             device=device,
                             non_blocking=torch.cuda.is_available(),
@@ -395,25 +395,16 @@ def run(
                             device=device,
                             non_blocking=torch.cuda.is_available(),
                         )
-                        unlabelled_images = samples[4].to(
-                            device=device,
-                            non_blocking=torch.cuda.is_available(),
-                        )
-                        # labelled outputs
+                        masks = None
+                        if len(samples) == 4:
+                            masks = samples[-1].to(
+                                device=device,
+                                non_blocking=torch.cuda.is_available(),
+                            )
+
                         outputs = model(images)
-                        unlabelled_outputs = model(unlabelled_images)
-                        # guessed unlabelled outputs
-                        unlabelled_ground_truths = guess_labels(
-                            unlabelled_images, model
-                        )
-                        loss, ll, ul = criterion(
-                            outputs,
-                            ground_truths,
-                            unlabelled_outputs,
-                            unlabelled_ground_truths,
-                            ramp_up_factor,
-                        )
 
+                        loss = criterion(outputs, ground_truths, masks)
                         valid_losses.update(loss)
 
             if checkpoint_period and (epoch % checkpoint_period == 0):