diff --git a/bob/ip/binseg/engine/ssltrainer.py b/bob/ip/binseg/engine/ssltrainer.py index 6d378268a2c5aed5e54b34ea88e86c69afe4763b..4a8cca1661cc600886f2e5fa655ae5a26245b920 100644 --- a/bob/ip/binseg/engine/ssltrainer.py +++ b/bob/ip/binseg/engine/ssltrainer.py @@ -376,6 +376,7 @@ def run( scheduler.step() # calculates the validation loss if necessary + # note: validation does not comprise "unlabelled" losses valid_losses = None if valid_loader is not None: @@ -385,8 +386,7 @@ def run( for samples in tqdm( valid_loader, desc="valid", leave=False, disable=None ): - - # labelled + # data forwarding on the existing network images = samples[1].to( device=device, non_blocking=torch.cuda.is_available(), @@ -395,25 +395,16 @@ def run( device=device, non_blocking=torch.cuda.is_available(), ) - unlabelled_images = samples[4].to( - device=device, - non_blocking=torch.cuda.is_available(), - ) - # labelled outputs + masks = None + if len(samples) == 4: + masks = samples[-1].to( + device=device, + non_blocking=torch.cuda.is_available(), + ) + outputs = model(images) - unlabelled_outputs = model(unlabelled_images) - # guessed unlabelled outputs - unlabelled_ground_truths = guess_labels( - unlabelled_images, model - ) - loss, ll, ul = criterion( - outputs, - ground_truths, - unlabelled_outputs, - unlabelled_ground_truths, - ramp_up_factor, - ) + loss = criterion(outputs, ground_truths, masks) valid_losses.update(loss) if checkpoint_period and (epoch % checkpoint_period == 0):