diff --git a/evaluation_pipeline.py b/evaluation_pipeline.py index dbced92329f91c19135014b2c589568137053253..dc47eadf1d8e5aa6a4d54aee0b2a5d8d5f39b4f9 100644 --- a/evaluation_pipeline.py +++ b/evaluation_pipeline.py @@ -82,14 +82,14 @@ from bob.bio.invert.wrappers import get_invert_pipeline import os,sys sys.path.append(os.getcwd()) # import src sys.path.append(f"{args.path_eg3d_repo}/eg3d") # import eg3d files -if args.attack='GaFaR': +if args.attack=='GaFaR': from transformers import GaFaR_InversionTransformer as InversionTransformer inv_transformer = InversionTransformer(checkpoint=args.checkpoint, eg3d_checkpoint=args.path_eg3d_checkpoint) -elif args.attack='GaFaR_CO': +elif args.attack=='GaFaR_CO': sys.path.append('./InsightFace-PyTorch') # import detect_align from transformers import GaFaR_CO_InversionTransformer as InversionTransformer inv_transformer = InversionTransformer(checkpoint=args.checkpoint, eg3d_checkpoint=args.path_eg3d_checkpoint, FR_system=args.FR_system) -elif args.attack='GaFaR_GS': +elif args.attack=='GaFaR_GS': sys.path.append('./InsightFace-PyTorch') # import detect_align from transformers import GaFaR_GS_InversionTransformer as InversionTransformer inv_transformer = InversionTransformer(checkpoint=args.checkpoint, eg3d_checkpoint=args.path_eg3d_checkpoint, FR_system=args.FR_system) diff --git a/transformers.py b/transformers.py index 9eff016ea9528282f3db10e5f4f2fefc7fe9c43c..a70c0b9a2e48e171b90720027a7987d9798e0d76 100644 --- a/transformers.py +++ b/transformers.py @@ -167,7 +167,7 @@ class GaFaR_CO_InversionTransformer(TransformerMixin, BaseEstimator): from detect_align import detectLM_align self.align = detectLM_align(detector_path= './InsightFace-PyTorch/retinaface/weights/mobilenet0.25_Final.pth', device=self.device) - self.FaceRecognition_transformer = get_FaceRecognition_transformer(FR_system=FR_system, self.device) + self.FaceRecognition_transformer = get_FaceRecognition_transformer(FR_system=FR_system, device=self.device) _ = self.FaceRecognition_transformer.transform(torch.zeros([1,3,112,112]).to(self.device))#._load_model(), eval() @@ -320,7 +320,7 @@ class GaFaR_GS_InversionTransformer(TransformerMixin, BaseEstimator): from detect_align import detectLM_align self.align = detectLM_align(detector_path= './InsightFace-PyTorch/retinaface/weights/mobilenet0.25_Final.pth', device=self.device) - self.FaceRecognition_transformer = get_FaceRecognition_transformer(FR_system=FR_system, self.device) + self.FaceRecognition_transformer = get_FaceRecognition_transformer(FR_system=FR_system, device=self.device) _ = self.FaceRecognition_transformer.transform(torch.zeros([1,3,112,112]).to(self.device))#._load_model(), eval()