diff --git a/bob/ip/binseg/engine/ssltrainer.py b/bob/ip/binseg/engine/ssltrainer.py
index dcc50005a05e6a29af1e1c7a76f1ad9bb6cdbb8f..40566f21aad29a1ed588ff67938dd7cd506d5cac 100644
--- a/bob/ip/binseg/engine/ssltrainer.py
+++ b/bob/ip/binseg/engine/ssltrainer.py
@@ -16,12 +16,12 @@ from tqdm import tqdm
 from ..utils.measure import SmoothedValue
 from ..utils.summary import summary
 from ..utils.resources import cpu_constants, gpu_constants, cpu_log, gpu_log
+from .trainer import PYTORCH_GE_110, torch_evaluation
 
 import logging
 
 logger = logging.getLogger(__name__)
 
-PYTORCH_GE_110 = distutils.version.StrictVersion(torch.__version__) >= "1.1.0"
 
 
 def sharpen(x, T):
@@ -371,31 +371,34 @@ def run(
             # calculates the validation loss if necessary
             valid_losses = None
             if valid_loader is not None:
-                valid_losses = SmoothedValue(len(valid_loader))
-                for samples in tqdm(
-                    valid_loader, desc="valid", leave=False, disable=None
-                ):
-
-                    # labelled
-                    images = samples[1].to(device)
-                    ground_truths = samples[2].to(device)
-                    unlabelled_images = samples[4].to(device)
-                    # labelled outputs
-                    outputs = model(images)
-                    unlabelled_outputs = model(unlabelled_images)
-                    # guessed unlabelled outputs
-                    unlabelled_ground_truths = guess_labels(
-                        unlabelled_images, model
-                    )
-                    loss, ll, ul = criterion(
-                        outputs,
-                        ground_truths,
-                        unlabelled_outputs,
-                        unlabelled_ground_truths,
-                        ramp_up_factor,
-                    )
-
-                    valid_losses.update(loss)
+
+                with torch.no_grad(), torch_evaluation(model):
+
+                    valid_losses = SmoothedValue(len(valid_loader))
+                    for samples in tqdm(
+                        valid_loader, desc="valid", leave=False, disable=None
+                    ):
+
+                        # labelled
+                        images = samples[1].to(device)
+                        ground_truths = samples[2].to(device)
+                        unlabelled_images = samples[4].to(device)
+                        # labelled outputs
+                        outputs = model(images)
+                        unlabelled_outputs = model(unlabelled_images)
+                        # guessed unlabelled outputs
+                        unlabelled_ground_truths = guess_labels(
+                            unlabelled_images, model
+                        )
+                        loss, ll, ul = criterion(
+                            outputs,
+                            ground_truths,
+                            unlabelled_outputs,
+                            unlabelled_ground_truths,
+                            ramp_up_factor,
+                        )
+
+                        valid_losses.update(loss)
 
             if checkpoint_period and (epoch % checkpoint_period == 0):
                 checkpointer.save(f"model_{epoch:03d}", **arguments)
diff --git a/bob/ip/binseg/engine/trainer.py b/bob/ip/binseg/engine/trainer.py
index 7d1f841bd4cd6f98fad641cd3ccffb9d98284ee2..ed34fbe0226c749c3e94ad59b57e23dfeece8b5c 100644
--- a/bob/ip/binseg/engine/trainer.py
+++ b/bob/ip/binseg/engine/trainer.py
@@ -7,6 +7,7 @@ import csv
 import time
 import shutil
 import datetime
+import contextlib
 import distutils.version
 
 import torch
@@ -23,6 +24,37 @@ logger = logging.getLogger(__name__)
 PYTORCH_GE_110 = distutils.version.StrictVersion(torch.__version__) >= "1.1.0"
 
 
+@contextlib.contextmanager
+def torch_evaluation(model):
+    """Context manager to turn ON/OFF model evaluation
+
+    This context manager will turn evaluation mode ON on entry and turn it OFF
+    when exiting the ``with`` statement block.  Training mode is restored in a
+    ``finally`` clause, so it is re-enabled even if the block raises.
+
+
+    Parameters
+    ----------
+
+    model : :py:class:`torch.nn.Module`
+        Network (e.g. driu, hed, unet)
+
+
+    Yields
+    ------
+
+    model : :py:class:`torch.nn.Module`
+        Network (e.g. driu, hed, unet)
+
+    """
+
+    model.eval()
+    try:
+        yield model
+    finally:
+        model.train()
+
+
 def run(
     model,
     data_loader,
@@ -203,21 +232,24 @@ def run(
             # calculates the validation loss if necessary
             valid_losses = None
             if valid_loader is not None:
-                valid_losses = SmoothedValue(len(valid_loader))
-                for samples in tqdm(
-                    valid_loader, desc="valid", leave=False, disable=None
-                ):
-                    # data forwarding on the existing network
-                    images = samples[1].to(device)
-                    ground_truths = samples[2].to(device)
-                    masks = None
-                    if len(samples) == 4:
-                        masks = samples[-1].to(device)
-
-                    outputs = model(images)
-
-                    loss = criterion(outputs, ground_truths, masks)
-                    valid_losses.update(loss)
+
+                with torch.no_grad(), torch_evaluation(model):
+
+                    valid_losses = SmoothedValue(len(valid_loader))
+                    for samples in tqdm(
+                        valid_loader, desc="valid", leave=False, disable=None
+                    ):
+                        # data forwarding on the existing network
+                        images = samples[1].to(device)
+                        ground_truths = samples[2].to(device)
+                        masks = None
+                        if len(samples) == 4:
+                            masks = samples[-1].to(device)
+
+                        outputs = model(images)
+
+                        loss = criterion(outputs, ground_truths, masks)
+                        valid_losses.update(loss)
 
             if checkpoint_period and (epoch % checkpoint_period == 0):
                 checkpointer.save(f"model_{epoch:03d}", **arguments)