From aaab33de0b7838e2325ebad11fb886ce43731f64 Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Wed, 13 May 2020 11:29:01 +0200
Subject: [PATCH] [engine.*trainer] Optimize validation during training with
 torch.no_grad() and model.eval()

---
 bob/ip/binseg/engine/ssltrainer.py | 55 +++++++++++++-------------
 bob/ip/binseg/engine/trainer.py    | 62 ++++++++++++++++++++++--------
 2 files changed, 76 insertions(+), 41 deletions(-)

diff --git a/bob/ip/binseg/engine/ssltrainer.py b/bob/ip/binseg/engine/ssltrainer.py
index dcc50005..40566f21 100644
--- a/bob/ip/binseg/engine/ssltrainer.py
+++ b/bob/ip/binseg/engine/ssltrainer.py
@@ -16,12 +16,12 @@ from tqdm import tqdm
 from ..utils.measure import SmoothedValue
 from ..utils.summary import summary
 from ..utils.resources import cpu_constants, gpu_constants, cpu_log, gpu_log
+from .trainer import PYTORCH_GE_110, torch_evaluation
 
 import logging
 
 logger = logging.getLogger(__name__)
 
-PYTORCH_GE_110 = distutils.version.StrictVersion(torch.__version__) >= "1.1.0"
 
 
 def sharpen(x, T):
@@ -371,31 +371,34 @@ def run(
         # calculates the validation loss if necessary
         valid_losses = None
         if valid_loader is not None:
-            valid_losses = SmoothedValue(len(valid_loader))
-            for samples in tqdm(
-                valid_loader, desc="valid", leave=False, disable=None
-            ):
-
-                # labelled
-                images = samples[1].to(device)
-                ground_truths = samples[2].to(device)
-                unlabelled_images = samples[4].to(device)
-                # labelled outputs
-                outputs = model(images)
-                unlabelled_outputs = model(unlabelled_images)
-                # guessed unlabelled outputs
-                unlabelled_ground_truths = guess_labels(
-                    unlabelled_images, model
-                )
-                loss, ll, ul = criterion(
-                    outputs,
-                    ground_truths,
-                    unlabelled_outputs,
-                    unlabelled_ground_truths,
-                    ramp_up_factor,
-                )
-
-                valid_losses.update(loss)
+
+            with torch.no_grad(), torch_evaluation(model):
+
+                valid_losses = SmoothedValue(len(valid_loader))
+                for samples in tqdm(
+                    valid_loader, desc="valid", leave=False, disable=None
+                ):
+
+                    # labelled
+                    images = samples[1].to(device)
+                    ground_truths = samples[2].to(device)
+                    unlabelled_images = samples[4].to(device)
+                    # labelled outputs
+                    outputs = model(images)
+                    unlabelled_outputs = model(unlabelled_images)
+                    # guessed unlabelled outputs
+                    unlabelled_ground_truths = guess_labels(
+                        unlabelled_images, model
+                    )
+                    loss, ll, ul = criterion(
+                        outputs,
+                        ground_truths,
+                        unlabelled_outputs,
+                        unlabelled_ground_truths,
+                        ramp_up_factor,
+                    )
+
+                    valid_losses.update(loss)
 
         if checkpoint_period and (epoch % checkpoint_period == 0):
             checkpointer.save(f"model_{epoch:03d}", **arguments)
diff --git a/bob/ip/binseg/engine/trainer.py b/bob/ip/binseg/engine/trainer.py
index 7d1f841b..ed34fbe0 100644
--- a/bob/ip/binseg/engine/trainer.py
+++ b/bob/ip/binseg/engine/trainer.py
@@ -7,6 +7,7 @@ import csv
 import time
 import shutil
 import datetime
+import contextlib
 import distutils.version
 
 import torch
@@ -23,6 +24,34 @@ logger = logging.getLogger(__name__)
 PYTORCH_GE_110 = distutils.version.StrictVersion(torch.__version__) >= "1.1.0"
 
 
+@contextlib.contextmanager
+def torch_evaluation(model):
+    """Context manager to turn ON/OFF model evaluation
+
+    This context manager will turn evaluation mode ON on entry and turn it OFF
+    when exiting the ``with`` statement block.
+
+
+    Parameters
+    ----------
+
+    model : :py:class:`torch.nn.Module`
+        Network (e.g. driu, hed, unet)
+
+
+    Yields
+    ------
+
+    model : :py:class:`torch.nn.Module`
+        Network (e.g. driu, hed, unet)
+
+    """
+
+    model.eval()
+    yield model
+    model.train()
+
+
 def run(
     model,
     data_loader,
@@ -203,21 +232,24 @@ def run(
         # calculates the validation loss if necessary
         valid_losses = None
         if valid_loader is not None:
-            valid_losses = SmoothedValue(len(valid_loader))
-            for samples in tqdm(
-                valid_loader, desc="valid", leave=False, disable=None
-            ):
-                # data forwarding on the existing network
-                images = samples[1].to(device)
-                ground_truths = samples[2].to(device)
-                masks = None
-                if len(samples) == 4:
-                    masks = samples[-1].to(device)
-
-                outputs = model(images)
-
-                loss = criterion(outputs, ground_truths, masks)
-                valid_losses.update(loss)
+
+            with torch.no_grad(), torch_evaluation(model):
+
+                valid_losses = SmoothedValue(len(valid_loader))
+                for samples in tqdm(
+                    valid_loader, desc="valid", leave=False, disable=None
+                ):
+                    # data forwarding on the existing network
+                    images = samples[1].to(device)
+                    ground_truths = samples[2].to(device)
+                    masks = None
+                    if len(samples) == 4:
+                        masks = samples[-1].to(device)
+
+                    outputs = model(images)
+
+                    loss = criterion(outputs, ground_truths, masks)
+                    valid_losses.update(loss)
 
         if checkpoint_period and (epoch % checkpoint_period == 0):
             checkpointer.save(f"model_{epoch:03d}", **arguments)
-- 
GitLab
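
Note (not part of the patch above): a minimal sketch of how the new torch_evaluation context manager is meant to be combined with torch.no_grad() during a validation pass. The tiny model and the random batches are hypothetical stand-ins, not code from bob.ip.binseg; in the package the helper would be imported from bob.ip.binseg.engine.trainer instead of redefined.

    import contextlib

    import torch
    import torch.nn as nn


    @contextlib.contextmanager
    def torch_evaluation(model):
        # Mirrors the helper added in trainer.py: switch the network to
        # evaluation mode on entry and back to training mode on exit.
        model.eval()
        yield model
        model.train()


    model = nn.Linear(4, 1)  # hypothetical stand-in for driu/hed/unet
    batches = [torch.rand(8, 4) for _ in range(3)]  # hypothetical validation data

    with torch.no_grad(), torch_evaluation(model):
        # Inside the block no autograd graph is built, and layers such as
        # dropout or batch-norm behave as they should at inference time.
        for batch in batches:
            outputs = model(batch)
            assert not outputs.requires_grad

    # Outside the block the model is back in training mode.
    assert model.training

Using both managers together is the point of the patch: torch.no_grad() removes the memory and compute cost of tracking gradients during validation, while torch_evaluation ensures the model is evaluated (and then restored to training mode) around the validation loop.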