Skip to content
Snippets Groups Projects

Fix grid

1 file
+ 5
3
Compare changes
  • Side-by-side
  • Inline
+ 5
3
@@ -178,9 +178,7 @@ class DeepPixBisClassifier(BaseEstimator, ClassifierMixin):
self.scoring_method
)
)
self.device = torch.device(
device or "cuda" if torch.cuda.is_available() else "cpu"
)
self.device = device
self.threshold = threshold
logger.debug(
@@ -275,5 +273,9 @@ class DeepPixBisClassifier(BaseEstimator, ClassifierMixin):
return {"requires_fit": False}
def place_model_on_device(self):
if self.device is None:
self.device = torch.device(
"cuda" if torch.cuda.is_available() else "cpu"
)
if self.model is not None:
self.model.to(self.device)
Loading