From 0f413ac4a0f9585636552d26a82e05d2de4b518f Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Wed, 13 May 2020 15:55:39 +0200
Subject: [PATCH] [engine.ssltrainer] Validation does not compute SSL loss,
 just standard model performance

---
 bob/ip/binseg/engine/ssltrainer.py | 29 ++++++++++-------------------
 1 file changed, 10 insertions(+), 19 deletions(-)

diff --git a/bob/ip/binseg/engine/ssltrainer.py b/bob/ip/binseg/engine/ssltrainer.py
index 6d378268..4a8cca16 100644
--- a/bob/ip/binseg/engine/ssltrainer.py
+++ b/bob/ip/binseg/engine/ssltrainer.py
@@ -376,6 +376,7 @@ def run(
                 scheduler.step()
 
             # calculates the validation loss if necessary
+            # note: validation does not comprise "unlabelled" losses
             valid_losses = None
             if valid_loader is not None:
 
@@ -385,8 +386,7 @@ def run(
                     for samples in tqdm(
                         valid_loader, desc="valid", leave=False, disable=None
                     ):
-
-                        # labelled
+                        # data forwarding on the existing network
                         images = samples[1].to(
                             device=device,
                             non_blocking=torch.cuda.is_available(),
@@ -395,25 +395,16 @@ def run(
                             device=device,
                             non_blocking=torch.cuda.is_available(),
                         )
-                        unlabelled_images = samples[4].to(
-                            device=device,
-                            non_blocking=torch.cuda.is_available(),
-                        )
-                        # labelled outputs
+                        masks = None
+                        if len(samples) == 4:
+                            masks = samples[-1].to(
+                                device=device,
+                                non_blocking=torch.cuda.is_available(),
+                            )
+
                         outputs = model(images)
-                        unlabelled_outputs = model(unlabelled_images)
-                        # guessed unlabelled outputs
-                        unlabelled_ground_truths = guess_labels(
-                            unlabelled_images, model
-                        )
-                        loss, ll, ul = criterion(
-                            outputs,
-                            ground_truths,
-                            unlabelled_outputs,
-                            unlabelled_ground_truths,
-                            ramp_up_factor,
-                        )
 
+                        loss = criterion(outputs, ground_truths, masks)
                         valid_losses.update(loss)
 
             if checkpoint_period and (epoch % checkpoint_period == 0):
-- 
GitLab