diff --git a/bob/ip/binseg/engine/predictor.py b/bob/ip/binseg/engine/predictor.py
index de78e69f5ba071d017babd5f7b3654244e64bb80..89d24d089d7f4eeece5d3974fe5e11ab90270eb0 100644
--- a/bob/ip/binseg/engine/predictor.py
+++ b/bob/ip/binseg/engine/predictor.py
@@ -17,6 +17,7 @@ import h5py
 from ..data.utils import overlayed_image
 
 import logging
+
 logger = logging.getLogger(__name__)
 
 
@@ -41,10 +42,12 @@ def _save_hdf5(stem, prob, output_folder):
     fullpath = os.path.join(output_folder, f"{stem}.hdf5")
     tqdm.write(f"Saving {fullpath}...")
     os.makedirs(os.path.dirname(fullpath), exist_ok=True)
-    with h5py.File(fullpath, 'w') as f:
+    with h5py.File(fullpath, "w") as f:
         data = prob.cpu().squeeze(0).numpy()
-        f.create_dataset("array", data=data, compression="gzip",
-                compression_opts=9)
+        f.create_dataset(
+            "array", data=data, compression="gzip", compression_opts=9
+        )
+
 
 def _save_image(stem, extension, data, output_folder):
     """Saves a PIL image into a file
@@ -95,7 +98,7 @@ def _save_overlayed_png(stem, image, prob, output_folder):
 
     image = VF.to_pil_image(image)
     prob = VF.to_pil_image(prob.cpu())
-    _save_image(stem, '.png', overlayed_image(image, prob), output_folder)
+    _save_image(stem, ".png", overlayed_image(image, prob), output_folder)
 
 
 def run(model, data_loader, device, output_folder, overlayed_folder):
@@ -136,11 +139,13 @@ def run(model, data_loader, device, output_folder, overlayed_folder):
     len_samples = []
 
     for samples in tqdm(
-            data_loader, desc="batches", leave=False, disable=None,
-            ):
+        data_loader, desc="batches", leave=False, disable=None,
+    ):
 
         names = samples[0]
-        images = samples[1].to(device)
+        images = samples[1].to(
+            device=device, non_blocking=torch.cuda.is_available()
+        )
 
         with torch.no_grad():
 
@@ -170,5 +175,7 @@ def run(model, data_loader, device, output_folder, overlayed_folder):
     average_batch_time = numpy.mean(times)
     logger.info(f"Average batch time: {average_batch_time:g}s")
 
-    average_image_time = numpy.sum(numpy.array(times) * len_samples) / float(sum(len_samples))
+    average_image_time = numpy.sum(numpy.array(times) * len_samples) / float(
+        sum(len_samples)
+    )
     logger.info(f"Average image time: {average_image_time:g}s")