From 0f413ac4a0f9585636552d26a82e05d2de4b518f Mon Sep 17 00:00:00 2001 From: Andre Anjos <andre.dos.anjos@gmail.com> Date: Wed, 13 May 2020 15:55:39 +0200 Subject: [PATCH] [engine.ssltrainer] Validation does not compute SSL loss, just standard model performance --- bob/ip/binseg/engine/ssltrainer.py | 29 ++++++++++------------------- 1 file changed, 10 insertions(+), 19 deletions(-) diff --git a/bob/ip/binseg/engine/ssltrainer.py b/bob/ip/binseg/engine/ssltrainer.py index 6d378268..4a8cca16 100644 --- a/bob/ip/binseg/engine/ssltrainer.py +++ b/bob/ip/binseg/engine/ssltrainer.py @@ -376,6 +376,7 @@ def run( scheduler.step() # calculates the validation loss if necessary + # note: validation does not comprise "unlabelled" losses valid_losses = None if valid_loader is not None: @@ -385,8 +386,7 @@ def run( for samples in tqdm( valid_loader, desc="valid", leave=False, disable=None ): - - # labelled + # data forwarding on the existing network images = samples[1].to( device=device, non_blocking=torch.cuda.is_available(), @@ -395,25 +395,16 @@ def run( device=device, non_blocking=torch.cuda.is_available(), ) - unlabelled_images = samples[4].to( - device=device, - non_blocking=torch.cuda.is_available(), - ) - # labelled outputs + masks = None + if len(samples) == 4: + masks = samples[-1].to( + device=device, + non_blocking=torch.cuda.is_available(), + ) + outputs = model(images) - unlabelled_outputs = model(unlabelled_images) - # guessed unlabelled outputs - unlabelled_ground_truths = guess_labels( - unlabelled_images, model - ) - loss, ll, ul = criterion( - outputs, - ground_truths, - unlabelled_outputs, - unlabelled_ground_truths, - ramp_up_factor, - ) + loss = criterion(outputs, ground_truths, masks) valid_losses.update(loss) if checkpoint_period and (epoch % checkpoint_period == 0): -- GitLab