diff --git a/bob/pad/face/deep_pix_bis.py b/bob/pad/face/deep_pix_bis.py index c6e976c82e0e6c9013fcc84082943069a83e4c7e..38ef6029dead068c776c322aa644acf89df38b30 100644 --- a/bob/pad/face/deep_pix_bis.py +++ b/bob/pad/face/deep_pix_bis.py @@ -178,9 +178,7 @@ class DeepPixBisClassifier(BaseEstimator, ClassifierMixin): self.scoring_method ) ) - self.device = torch.device( - device or "cuda" if torch.cuda.is_available() else "cpu" - ) + self.device = device self.threshold = threshold logger.debug( @@ -275,5 +273,9 @@ class DeepPixBisClassifier(BaseEstimator, ClassifierMixin): return {"requires_fit": False} def place_model_on_device(self): + if self.device is None: + self.device = torch.device( + "cuda" if torch.cuda.is_available() else "cpu" + ) if self.model is not None: self.model.to(self.device)