From e3e460a2bc2507602558abaa47d5ccf0801f25d4 Mon Sep 17 00:00:00 2001 From: Manuel Guenther <guenther@ifi.uzh.ch> Date: Thu, 24 Nov 2022 17:05:39 +0100 Subject: [PATCH] Passed device to oxford_vgg_resnets --- bob/bio/face/embeddings/pytorch.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/bob/bio/face/embeddings/pytorch.py b/bob/bio/face/embeddings/pytorch.py index 790d2641..5bf79723 100644 --- a/bob/bio/face/embeddings/pytorch.py +++ b/bob/bio/face/embeddings/pytorch.py @@ -1210,7 +1210,11 @@ def afffe_baseline( def oxford_vgg2_resnets( - model_name, annotation_type, fixed_positions=None, memory_demanding=False + model_name, + annotation_type, + fixed_positions=None, + memory_demanding=False, + device=torch.device("cpu"), ): """ Get the pipeline for the resnet based models from Oxford. @@ -1244,7 +1248,9 @@ def oxford_vgg2_resnets( transformer = embedding_transformer( cropped_image_size=cropped_image_size, embedding=OxfordVGG2Resnets( - model_name=model_name, memory_demanding=memory_demanding + model_name=model_name, + memory_demanding=memory_demanding + device=device ), cropped_positions=cropped_positions, fixed_positions=fixed_positions, -- GitLab