diff --git a/bob/ip/binseg/engine/ssltrainer.py b/bob/ip/binseg/engine/ssltrainer.py
index 98c17e1b56c7e0732870f36d005dbf395446da74..6d378268a2c5aed5e54b34ea88e86c69afe4763b 100644
--- a/bob/ip/binseg/engine/ssltrainer.py
+++ b/bob/ip/binseg/engine/ssltrainer.py
@@ -23,7 +23,6 @@ import logging
 logger = logging.getLogger(__name__)
 
 
-
 def sharpen(x, T):
     temp = x ** (1 / T)
     return temp / temp.sum(dim=1, keepdim=True)
@@ -300,7 +299,7 @@ def run(
         if arguments["epoch"] == 0:
             logwriter.writeheader()
 
-        model.train()  #set training mode
+        model.train()  # set training mode
 
         model.to(device)  # set/cast parameters to device
         for state in optimizer.state.values():
@@ -336,9 +335,15 @@ def run(
                 # data forwarding on the existing network
 
                 # labelled
-                images = samples[1].to(device)
-                ground_truths = samples[2].to(device)
-                unlabelled_images = samples[4].to(device)
+                images = samples[1].to(
+                    device=device, non_blocking=torch.cuda.is_available()
+                )
+                ground_truths = samples[2].to(
+                    device=device, non_blocking=torch.cuda.is_available()
+                )
+                unlabelled_images = samples[4].to(
+                    device=device, non_blocking=torch.cuda.is_available()
+                )
                 # labelled outputs
                 outputs = model(images)
                 unlabelled_outputs = model(unlabelled_images)
@@ -382,9 +387,18 @@ def run(
                     ):
 
                         # labelled
-                        images = samples[1].to(device)
-                        ground_truths = samples[2].to(device)
-                        unlabelled_images = samples[4].to(device)
+                        images = samples[1].to(
+                            device=device,
+                            non_blocking=torch.cuda.is_available(),
+                        )
+                        ground_truths = samples[2].to(
+                            device=device,
+                            non_blocking=torch.cuda.is_available(),
+                        )
+                        unlabelled_images = samples[4].to(
+                            device=device,
+                            non_blocking=torch.cuda.is_available(),
+                        )
                         # labelled outputs
                         outputs = model(images)
                         unlabelled_outputs = model(unlabelled_images)
diff --git a/bob/ip/binseg/engine/trainer.py b/bob/ip/binseg/engine/trainer.py
index 4c4e558e66bdfdba99d066ca9060d7d4a41678df..e0299cc1f1a29ccbb1fccc8c45146dd15eeae1e3 100644
--- a/bob/ip/binseg/engine/trainer.py
+++ b/bob/ip/binseg/engine/trainer.py
@@ -179,7 +179,7 @@ def run(
         if arguments["epoch"] == 0:
             logwriter.writeheader()
 
-        model.train()  #set training mode
+        model.train()  # set training mode
 
         model.to(device)  # set/cast parameters to device
         for state in optimizer.state.values():
@@ -211,11 +211,17 @@ def run(
             ):
 
                 # data forwarding on the existing network
-                images = samples[1].to(device)
-                ground_truths = samples[2].to(device)
+                images = samples[1].to(
+                    device=device, non_blocking=torch.cuda.is_available()
+                )
+                ground_truths = samples[2].to(
+                    device=device, non_blocking=torch.cuda.is_available()
+                )
                 masks = None
                 if len(samples) == 4:
-                    masks = samples[-1].to(device)
+                    masks = samples[-1].to(
+                        device=device, non_blocking=torch.cuda.is_available()
+                    )
 
                 outputs = model(images)
 
@@ -242,11 +248,20 @@ def run(
                         valid_loader, desc="valid", leave=False, disable=None
                     ):
                         # data forwarding on the existing network
-                        images = samples[1].to(device)
-                        ground_truths = samples[2].to(device)
+                        images = samples[1].to(
+                            device=device,
+                            non_blocking=torch.cuda.is_available(),
+                        )
+                        ground_truths = samples[2].to(
+                            device=device,
+                            non_blocking=torch.cuda.is_available(),
+                        )
                         masks = None
                         if len(samples) == 4:
-                            masks = samples[-1].to(device)
+                            masks = samples[-1].to(
+                                device=device,
+                                non_blocking=torch.cuda.is_available(),
+                            )
 
                         outputs = model(images)