diff --git a/bob/ip/binseg/engine/predictor.py b/bob/ip/binseg/engine/predictor.py
index 95fc1b7606261a9ad9de493bfe62f04dcf90e5f6..4615b76560ee760f3af18474458161d3bb34fe65 100644
--- a/bob/ip/binseg/engine/predictor.py
+++ b/bob/ip/binseg/engine/predictor.py
@@ -116,8 +116,8 @@ def run(model, data_loader, name, device, output_folder, overlayed_folder):
         the local name of this dataset (e.g. ``train``, or ``test``), to be
         used when saving measures files.
 
-    device : str
-        device to use ``cpu`` or ``cuda:0``
+    device : :py:class:`torch.device`
+        device to use
 
     output_folder : str
         folder where to store output prediction maps (HDF5 files) and model
diff --git a/bob/ip/binseg/engine/trainer.py b/bob/ip/binseg/engine/trainer.py
index 1c587de7bf42dbae21146e37e0a25ae8198333c6..d3717050f9c4ce3cd67c4c334a4e2d2230311fbf 100644
--- a/bob/ip/binseg/engine/trainer.py
+++ b/bob/ip/binseg/engine/trainer.py
@@ -100,8 +100,8 @@ def run(
         save a checkpoint every ``n`` epochs.  If set to ``0`` (zero), then do
         not save intermediary checkpoints
 
-    device : str
-        device to use ``'cpu'`` or ``cuda:0``
+    device : :py:class:`torch.device`
+        device to use
 
     arguments : dict
         start and end epochs
@@ -113,11 +113,10 @@ def run(
     start_epoch = arguments["epoch"]
     max_epoch = arguments["max_epoch"]
 
-    if device != "cpu":
+    if device.type == "cuda":
         # asserts we do have a GPU
         assert bool(gpu_constants()), (
-            f"Device set to '{device}', but cannot "
-            f"find a GPU (maybe nvidia-smi is not installed?)"
+            f"Device set to '{device}', but nvidia-smi is not installed"
         )
 
     os.makedirs(output_folder, exist_ok=True)
@@ -139,7 +138,7 @@ def run(
         shutil.move(static_logfile_name, backup)
     with open(static_logfile_name, "w", newline="") as f:
         logdata = cpu_constants()
-        if device != "cpu":
+        if device.type == "cuda":
             logdata += gpu_constants()
         logdata += (("model_size", n),)
         logwriter = csv.DictWriter(f, fieldnames=[k[0] for k in logdata])
@@ -166,7 +165,7 @@ def run(
     if valid_loader is not None:
         logfile_fields += ("validation_average_loss", "validation_median_loss")
     logfile_fields += tuple([k[0] for k in cpu_log()])
-    if device != "cpu":
+    if device.type == "cuda":
         logfile_fields += tuple([k[0] for k in gpu_log()])
 
     # the lowest validation loss obtained so far - this value is updated only
@@ -308,7 +307,7 @@ def run(
                     ("validation_median_loss", f"{valid_losses.median:.6f}"),
                 )
             logdata += cpu_log()
-            if device != "cpu":
+            if device.type == "cuda":
                 logdata += gpu_log()
 
             logwriter.writerow(dict(k for k in logdata))
diff --git a/bob/ip/binseg/script/binseg.py b/bob/ip/binseg/script/binseg.py
index 091792acd02bd5629d94d1ccf64bed4bdcedaec6..afa00f950b674e8c579fe01dabb9d9520c883946 100644
--- a/bob/ip/binseg/script/binseg.py
+++ b/bob/ip/binseg/script/binseg.py
@@ -7,6 +7,7 @@ import os
 import re
 import sys
 import time
+import random
 import tempfile
 import urllib.request
 
@@ -15,12 +16,100 @@ import click
 from click_plugins import with_plugins
 from tqdm import tqdm
 
+import numpy
+import torch
+
 from bob.extension.scripts.click_helper import AliasedGroup
 
 import logging
 logger = logging.getLogger(__name__)
 
 
+def setup_pytorch_device(name):
+    """Sets-up the pytorch device to use
+
+
+    Parameters
+    ----------
+
+    name : str
+        The device name (``cpu``, ``cuda:0``, ``cuda:1``, and so on)
+
+
+    Returns
+    -------
+
+    device : :py:class:`torch.device`
+        The pytorch device to use, pre-configured (and checked)
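+
+    Examples
+    --------
+
+    A minimal usage sketch (``cpu`` needs no CUDA setup; ``cuda:N`` names
+    are checked against ``torch.cuda.is_available()``):
+
+    >>> device = setup_pytorch_device("cpu")
+    >>> device.type
+    'cpu'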
+
+    """
+
+    if name.startswith("cuda"):
+        # In case one has multiple devices, we must first set the one
+        # we would like to use so pytorch can find it.  A bare ``cuda``
+        # (no index) lets pytorch pick the default device.
+        if ":" in name:
+            os.environ['CUDA_VISIBLE_DEVICES'] = name.split(":", 1)[1]
+        if not torch.cuda.is_available():
+            raise RuntimeError(
+                f"CUDA is not currently available, but "
+                f"you set device to '{name}'"
+            )
+        # Let pytorch auto-select from the environment variable
+        return torch.device("cuda")
+
+    # cpu
+    return torch.device(name)
+
+
+def set_seeds(value, all_gpus):
+    """Sets up all relevant random seeds (numpy, python, cuda)
+
+    If running with multiple GPUs **at the same time**, set ``all_gpus`` to
+    ``True`` to force all GPU seeds to be initialized.
+
+    Reference: `PyTorch page for reproducibility
+    <https://pytorch.org/docs/stable/notes/randomness.html>`_.
+
+
+    Parameters
+    ----------
+
+    value : int
+        The random seed value to use
+
+    all_gpus : :py:class:`bool`
+        If set, then reset the seed on all GPUs available at once.  This is
+        normally **not** what you want if running on a single GPU
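+
+    Examples
+    --------
+
+    A minimal sketch for single-device training (the seed value is
+    arbitrary):
+
+    >>> set_seeds(42, all_gpus=False)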
+
+    """
+
+    random.seed(value)
+    numpy.random.seed(value)
+    torch.manual_seed(value)
+    torch.cuda.manual_seed(value)  # no-op if cuda is not available
+
+    # set seeds for all gpus
+    if all_gpus:
+        torch.cuda.manual_seed_all(value)  # no-op if cuda is not available
+
+
+def set_reproducible_cuda():
+    """Turns-off all CUDA optimizations that would affect reproducibility
+
+    For full reproducibility, also ensure not to use multiple (parallel) data
+    loaders.  That is, set ``num_workers=0``.
+
+    Reference: `PyTorch page for reproducibility
+    <https://pytorch.org/docs/stable/notes/randomness.html>`_.
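+
+    Examples
+    --------
+
+    A minimal sketch of a reproducible setup, combining this helper with
+    ``set_seeds`` (the seed value is arbitrary, and data loaders should
+    also be built with ``num_workers=0``):
+
+    >>> set_seeds(42, all_gpus=False)
+    >>> set_reproducible_cuda()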
+
+
+    """
+
+    # ensure we only use cuda optimization algorithms that are known to
+    # have a deterministic effect (not random)
+    torch.backends.cudnn.deterministic = True
+
+    # turns off the cuDNN auto-tuner, which could otherwise select a
+    # different (possibly non-deterministic) algorithm at each run
+    torch.backends.cudnn.benchmark = False
+
+
 def escape_name(v):
     """Escapes a name so it contains filesystem friendly characters only
 
diff --git a/bob/ip/binseg/script/predict.py b/bob/ip/binseg/script/predict.py
index f87e8bcbd3afd4f130441a6e95287fd2666e5687..da458905b8b004d6ccbc3646e5280fef23f89ad7 100644
--- a/bob/ip/binseg/script/predict.py
+++ b/bob/ip/binseg/script/predict.py
@@ -17,7 +17,7 @@ from bob.extension.scripts.click_helper import (
 from ..engine.predictor import run
 from ..utils.checkpointer import Checkpointer
 
-from .binseg import download_to_tempfile
+from .binseg import download_to_tempfile, setup_pytorch_device
 
 import logging
 logger = logging.getLogger(__name__)
@@ -115,6 +115,8 @@ def predict(output_folder, model, dataset, batch_size, device, weight,
         overlayed, **kwargs):
     """Predicts vessel map (probabilities) on input images"""
 
+    device = setup_pytorch_device(device)
+
     dataset = dataset if isinstance(dataset, dict) else dict(test=dataset)
 
     if weight.startswith("http"):
diff --git a/bob/ip/binseg/script/train.py b/bob/ip/binseg/script/train.py
index 5edcd7e708a486d410b87b751108b025204368a6..94b20dab5e3764cd009da1676645f65686c553b2 100644
--- a/bob/ip/binseg/script/train.py
+++ b/bob/ip/binseg/script/train.py
@@ -14,6 +14,7 @@ from bob.extension.scripts.click_helper import (
 )
 
 from ..utils.checkpointer import Checkpointer
+from .binseg import setup_pytorch_device, set_seeds
 
 import logging
 logger = logging.getLogger(__name__)
@@ -216,7 +217,9 @@ def train(
     abruptly.
     """
 
-    torch.manual_seed(seed)
+    device = setup_pytorch_device(device)
+
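+    # seed the python, numpy and pytorch RNGs; all_gpus=False since training
+    # runs on a single device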
+    set_seeds(seed, all_gpus=False)
 
     use_dataset = dataset
     validation_dataset = None
diff --git a/doc/extras.inv b/doc/extras.inv
index f02c5b9fd164f673297c60f44cb6f75434a83d82..1fcd9a80e99a758b7a7bc41d2f2a9c98013f3264 100644
Binary files a/doc/extras.inv and b/doc/extras.inv differ
diff --git a/doc/extras.txt b/doc/extras.txt
index ee07f144b5f4c55a6b1fb3feabe117b64c787dc2..a285d06c679ed32f27a7498ca774b40ecb6b3972 100644
--- a/doc/extras.txt
+++ b/doc/extras.txt
@@ -2,6 +2,7 @@
 # Project: extras
 # Version: stable
 # The remainder of this file is compressed using zlib.
+torch.device py:class 1 https://pytorch.org/docs/stable/tensor_attributes.html#torch.torch.device -
 torch.optim.optimizer.Optimizer py:class 1 https://pytorch.org/docs/stable/optim.html#torch.optim.Optimizer -
 torch.nn.Module py:class 1 https://pytorch.org/docs/stable/nn.html?highlight=module#torch.nn.Module -
 torch.nn.modules.module.Module py:class 1 https://pytorch.org/docs/stable/nn.html?highlight=module#torch.nn.Module -