diff --git a/bob/pad/face/deep_pix_bis.py b/bob/pad/face/deep_pix_bis.py
index c6e976c82e0e6c9013fcc84082943069a83e4c7e..38ef6029dead068c776c322aa644acf89df38b30 100644
--- a/bob/pad/face/deep_pix_bis.py
+++ b/bob/pad/face/deep_pix_bis.py
@@ -178,9 +178,7 @@ class DeepPixBisClassifier(BaseEstimator, ClassifierMixin):
                     self.scoring_method
                 )
             )
-        self.device = torch.device(
-            device or "cuda" if torch.cuda.is_available() else "cpu"
-        )
+        self.device = device
         self.threshold = threshold
 
         logger.debug(
@@ -275,5 +273,9 @@ class DeepPixBisClassifier(BaseEstimator, ClassifierMixin):
         return {"requires_fit": False}
 
     def place_model_on_device(self):
+        if self.device is None:
+            self.device = torch.device(
+                "cuda" if torch.cuda.is_available() else "cpu"
+            )
         if self.model is not None:
             self.model.to(self.device)