From 4909eba9ee173c325cb921dba852edaa1f7f0c17 Mon Sep 17 00:00:00 2001
From: Hatef OTROSHI <hatef.otroshi@idiap.ch>
Date: Tue, 30 Apr 2024 09:07:00 +0200
Subject: [PATCH] fix

---
 evaluation_pipeline.py | 6 +++---
 transformers.py        | 4 ++--
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/evaluation_pipeline.py b/evaluation_pipeline.py
index dbced92..dc47ead 100644
--- a/evaluation_pipeline.py
+++ b/evaluation_pipeline.py
@@ -82,14 +82,14 @@ from bob.bio.invert.wrappers import get_invert_pipeline
 import os,sys
 sys.path.append(os.getcwd()) # import src
 sys.path.append(f"{args.path_eg3d_repo}/eg3d") # import eg3d files
-if args.attack='GaFaR':
+if args.attack=='GaFaR':
     from transformers import GaFaR_InversionTransformer as InversionTransformer
     inv_transformer = InversionTransformer(checkpoint=args.checkpoint, eg3d_checkpoint=args.path_eg3d_checkpoint)
-elif args.attack='GaFaR_CO':
+elif args.attack=='GaFaR_CO':
     sys.path.append('./InsightFace-PyTorch') # import detect_align
     from transformers import GaFaR_CO_InversionTransformer as InversionTransformer
     inv_transformer = InversionTransformer(checkpoint=args.checkpoint, eg3d_checkpoint=args.path_eg3d_checkpoint, FR_system=args.FR_system)
-elif args.attack='GaFaR_GS':
+elif args.attack=='GaFaR_GS':
     sys.path.append('./InsightFace-PyTorch') # import detect_align
     from transformers import GaFaR_GS_InversionTransformer as InversionTransformer
     inv_transformer = InversionTransformer(checkpoint=args.checkpoint, eg3d_checkpoint=args.path_eg3d_checkpoint, FR_system=args.FR_system)
diff --git a/transformers.py b/transformers.py
index 9eff016..a70c0b9 100644
--- a/transformers.py
+++ b/transformers.py
@@ -167,7 +167,7 @@ class GaFaR_CO_InversionTransformer(TransformerMixin, BaseEstimator):
         from detect_align import detectLM_align
         self.align = detectLM_align(detector_path= './InsightFace-PyTorch/retinaface/weights/mobilenet0.25_Final.pth', device=self.device)
 
-        self.FaceRecognition_transformer = get_FaceRecognition_transformer(FR_system=FR_system, self.device)
+        self.FaceRecognition_transformer = get_FaceRecognition_transformer(FR_system=FR_system, device=self.device)
         _ = self.FaceRecognition_transformer.transform(torch.zeros([1,3,112,112]).to(self.device))#._load_model(), eval()
 
         
@@ -320,7 +320,7 @@ class GaFaR_GS_InversionTransformer(TransformerMixin, BaseEstimator):
         from detect_align import detectLM_align
         self.align = detectLM_align(detector_path= './InsightFace-PyTorch/retinaface/weights/mobilenet0.25_Final.pth', device=self.device)
 
-        self.FaceRecognition_transformer = get_FaceRecognition_transformer(FR_system=FR_system, self.device)
+        self.FaceRecognition_transformer = get_FaceRecognition_transformer(FR_system=FR_system, device=self.device)
         _ = self.FaceRecognition_transformer.transform(torch.zeros([1,3,112,112]).to(self.device))#._load_model(), eval()
 
         
-- 
GitLab