Skip to content
Snippets Groups Projects
Commit 944573ab authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

Merge branch 'fix_grid' into 'master'

Fix grid

See merge request !129
parents caca9f61 82793978
Branches
Tags
1 merge request!129Fix grid
Pipeline #62059 passed
......@@ -178,9 +178,7 @@ class DeepPixBisClassifier(BaseEstimator, ClassifierMixin):
self.scoring_method
)
)
self.device = torch.device(
device or "cuda" if torch.cuda.is_available() else "cpu"
)
self.device = device
self.threshold = threshold
logger.debug(
......@@ -275,5 +273,9 @@ class DeepPixBisClassifier(BaseEstimator, ClassifierMixin):
return {"requires_fit": False}
def place_model_on_device(self):
if self.device is None:
self.device = torch.device(
"cuda" if torch.cuda.is_available() else "cpu"
)
if self.model is not None:
self.model.to(self.device)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment