Skip to content
Snippets Groups Projects
Commit 0f413ac4 authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[engine.ssltrainer] Validation does not compute SSL loss, just standard model performance

parent 687f8b24
No related branches found
No related tags found
No related merge requests found
......@@ -376,6 +376,7 @@ def run(
scheduler.step()
# calculates the validation loss if necessary
# note: validation does not comprise "unlabelled" losses
valid_losses = None
if valid_loader is not None:
......@@ -385,8 +386,7 @@ def run(
for samples in tqdm(
valid_loader, desc="valid", leave=False, disable=None
):
# labelled
# data forwarding on the existing network
images = samples[1].to(
device=device,
non_blocking=torch.cuda.is_available(),
......@@ -395,25 +395,16 @@ def run(
device=device,
non_blocking=torch.cuda.is_available(),
)
unlabelled_images = samples[4].to(
device=device,
non_blocking=torch.cuda.is_available(),
)
# labelled outputs
masks = None
if len(samples) == 4:
masks = samples[-1].to(
device=device,
non_blocking=torch.cuda.is_available(),
)
outputs = model(images)
unlabelled_outputs = model(unlabelled_images)
# guessed unlabelled outputs
unlabelled_ground_truths = guess_labels(
unlabelled_images, model
)
loss, ll, ul = criterion(
outputs,
ground_truths,
unlabelled_outputs,
unlabelled_ground_truths,
ramp_up_factor,
)
loss = criterion(outputs, ground_truths, masks)
valid_losses.update(loss)
if checkpoint_period and (epoch % checkpoint_period == 0):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment