From e3e460a2bc2507602558abaa47d5ccf0801f25d4 Mon Sep 17 00:00:00 2001
From: Manuel Guenther <guenther@ifi.uzh.ch>
Date: Thu, 24 Nov 2022 17:05:39 +0100
Subject: [PATCH] Passed device to oxford_vgg_resnets

---
 bob/bio/face/embeddings/pytorch.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/bob/bio/face/embeddings/pytorch.py b/bob/bio/face/embeddings/pytorch.py
index 790d2641..5bf79723 100644
--- a/bob/bio/face/embeddings/pytorch.py
+++ b/bob/bio/face/embeddings/pytorch.py
@@ -1210,7 +1210,11 @@ def afffe_baseline(
 
 
 def oxford_vgg2_resnets(
-    model_name, annotation_type, fixed_positions=None, memory_demanding=False
+    model_name,
+    annotation_type,
+    fixed_positions=None,
+    memory_demanding=False,
+    device=torch.device("cpu"),
 ):
     """
     Get the pipeline for the resnet based models from Oxford.
@@ -1244,7 +1248,9 @@ def oxford_vgg2_resnets(
     transformer = embedding_transformer(
         cropped_image_size=cropped_image_size,
         embedding=OxfordVGG2Resnets(
-            model_name=model_name, memory_demanding=memory_demanding
+            model_name=model_name,
+            memory_demanding=memory_demanding
+            device=device
         ),
         cropped_positions=cropped_positions,
         fixed_positions=fixed_positions,
-- 
GitLab