Skip to content
Snippets Groups Projects

fix device for input tensor (PyTorch)

Merged Hatef OTROSHI requested to merge fix-pytorch into master
1 file
+ 1
1
Compare changes
  • Side-by-side
  • Inline
@@ -81,7 +81,7 @@ class PyTorchModel(TransformerMixin, BaseEstimator):
def _transform(X):
with torch.no_grad():
return self.model(X).cpu().detach().numpy()
return self.model(X.to(self.device)).cpu().detach().numpy()
if self.memory_demanding:
return np.array([_transform(x[None, ...]) for x in X])
Loading