diff --git a/bob/bio/face/embeddings/pytorch.py b/bob/bio/face/embeddings/pytorch.py index 790d264157481f93fdbcaecd530664d7298abb9d..5bf7972311d82790ba40435112e071bd59ada799 100644 --- a/bob/bio/face/embeddings/pytorch.py +++ b/bob/bio/face/embeddings/pytorch.py @@ -1210,7 +1210,11 @@ def afffe_baseline( def oxford_vgg2_resnets( - model_name, annotation_type, fixed_positions=None, memory_demanding=False + model_name, + annotation_type, + fixed_positions=None, + memory_demanding=False, + device=torch.device("cpu"), ): """ Get the pipeline for the resnet based models from Oxford. @@ -1244,7 +1248,9 @@ def oxford_vgg2_resnets( transformer = embedding_transformer( cropped_image_size=cropped_image_size, embedding=OxfordVGG2Resnets( - model_name=model_name, memory_demanding=memory_demanding + model_name=model_name, + memory_demanding=memory_demanding + device=device ), cropped_positions=cropped_positions, fixed_positions=fixed_positions,