diff --git a/bob/ip/binseg/engine/ssltrainer.py b/bob/ip/binseg/engine/ssltrainer.py index dcc50005a05e6a29af1e1c7a76f1ad9bb6cdbb8f..40566f21aad29a1ed588ff67938dd7cd506d5cac 100644 --- a/bob/ip/binseg/engine/ssltrainer.py +++ b/bob/ip/binseg/engine/ssltrainer.py @@ -16,12 +16,12 @@ from tqdm import tqdm from ..utils.measure import SmoothedValue from ..utils.summary import summary from ..utils.resources import cpu_constants, gpu_constants, cpu_log, gpu_log +from .trainer import PYTORCH_GE_110, torch_evaluation import logging logger = logging.getLogger(__name__) -PYTORCH_GE_110 = distutils.version.StrictVersion(torch.__version__) >= "1.1.0" def sharpen(x, T): @@ -371,31 +371,34 @@ def run( # calculates the validation loss if necessary valid_losses = None if valid_loader is not None: - valid_losses = SmoothedValue(len(valid_loader)) - for samples in tqdm( - valid_loader, desc="valid", leave=False, disable=None - ): - - # labelled - images = samples[1].to(device) - ground_truths = samples[2].to(device) - unlabelled_images = samples[4].to(device) - # labelled outputs - outputs = model(images) - unlabelled_outputs = model(unlabelled_images) - # guessed unlabelled outputs - unlabelled_ground_truths = guess_labels( - unlabelled_images, model - ) - loss, ll, ul = criterion( - outputs, - ground_truths, - unlabelled_outputs, - unlabelled_ground_truths, - ramp_up_factor, - ) - - valid_losses.update(loss) + + with torch.no_grad(), torch_evaluation(model): + + valid_losses = SmoothedValue(len(valid_loader)) + for samples in tqdm( + valid_loader, desc="valid", leave=False, disable=None + ): + + # labelled + images = samples[1].to(device) + ground_truths = samples[2].to(device) + unlabelled_images = samples[4].to(device) + # labelled outputs + outputs = model(images) + unlabelled_outputs = model(unlabelled_images) + # guessed unlabelled outputs + unlabelled_ground_truths = guess_labels( + unlabelled_images, model + ) + loss, ll, ul = criterion( + outputs, + ground_truths, + unlabelled_outputs, + unlabelled_ground_truths, + ramp_up_factor, + ) + + valid_losses.update(loss) if checkpoint_period and (epoch % checkpoint_period == 0): checkpointer.save(f"model_{epoch:03d}", **arguments) diff --git a/bob/ip/binseg/engine/trainer.py b/bob/ip/binseg/engine/trainer.py index 7d1f841bd4cd6f98fad641cd3ccffb9d98284ee2..ed34fbe0226c749c3e94ad59b57e23dfeece8b5c 100644 --- a/bob/ip/binseg/engine/trainer.py +++ b/bob/ip/binseg/engine/trainer.py @@ -7,6 +7,7 @@ import csv import time import shutil import datetime +import contextlib import distutils.version import torch @@ -23,6 +24,34 @@ logger = logging.getLogger(__name__) PYTORCH_GE_110 = distutils.version.StrictVersion(torch.__version__) >= "1.1.0" +@contextlib.contextmanager +def torch_evaluation(model): + """Context manager to turn ON/OFF model evaluation + + This context manager will turn evaluation mode ON on entry and turn it OFF + when exiting the ``with`` statement block. + + + Parameters + ---------- + + model : :py:class:`torch.nn.Module` + Network (e.g. driu, hed, unet) + + + Yields + ------ + + model : :py:class:`torch.nn.Module` + Network (e.g. driu, hed, unet) + + """ + + model.eval() + yield model + model.train() + + def run( model, data_loader, @@ -203,21 +232,24 @@ def run( # calculates the validation loss if necessary valid_losses = None if valid_loader is not None: - valid_losses = SmoothedValue(len(valid_loader)) - for samples in tqdm( - valid_loader, desc="valid", leave=False, disable=None - ): - # data forwarding on the existing network - images = samples[1].to(device) - ground_truths = samples[2].to(device) - masks = None - if len(samples) == 4: - masks = samples[-1].to(device) - - outputs = model(images) - - loss = criterion(outputs, ground_truths, masks) - valid_losses.update(loss) + + with torch.no_grad(), torch_evaluation(model): + + valid_losses = SmoothedValue(len(valid_loader)) + for samples in tqdm( + valid_loader, desc="valid", leave=False, disable=None + ): + # data forwarding on the existing network + images = samples[1].to(device) + ground_truths = samples[2].to(device) + masks = None + if len(samples) == 4: + masks = samples[-1].to(device) + + outputs = model(images) + + loss = criterion(outputs, ground_truths, masks) + valid_losses.update(loss) if checkpoint_period and (epoch % checkpoint_period == 0): checkpointer.save(f"model_{epoch:03d}", **arguments)