Skip to content
Snippets Groups Projects

Resolve "Pytorch device is not followed in embeddings"

Merged Manuel Günther requested to merge 68-pytorch-device-is-not-followed-in-embeddings into master
1 file
+ 3
10
Compare changes
  • Side-by-side
  • Inline
@@ -56,7 +56,7 @@ class PyTorchModel(TransformerMixin, BaseEstimator):
self.model = None
self.preprocessor = preprocessor
self.memory_demanding = memory_demanding
self.device = device
self.device = torch.device(device or "cuda" if torch.cuda.is_available() else "cpu")
def transform(self, X):
"""__call__(image) -> feature
@@ -110,16 +110,9 @@ class PyTorchModel(TransformerMixin, BaseEstimator):
def _more_tags(self):
return {"stateless": True, "requires_fit": False}
def place_model_on_device(self, device=None):
import torch
if device is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.device = device
def place_model_on_device(self):
if self.model is not None:
self.model.to(device)
self.model.to(self.device)
class AFFFE_2021(PyTorchModel):
Loading