Commit e0ffc354 authored by Hatef OTROSHI

+ train and eval

parent 173844f7
import numpy as np

import bob.bio.base
from bob.measure import far_threshold

# Load comparison scores (genuine/impostor) and template-inversion attack scores
impostors, genuines = bob.bio.base.score.load.split_csv_scores('results/scores-dev.csv')
_, inverted = bob.bio.base.score.load.split_csv_scores('results/scores_inversion-dev.csv')

def fmr_fnmr(neg, pos, threshold):
    """Compute the false match rate and false non-match rate at the given threshold."""
    fmr = np.mean(np.where(np.array(neg) >= threshold, 1, 0))
    fnmr = np.mean(np.where(np.array(pos) < threshold, 1, 0))
    return fmr, fnmr

for FMR in [1e-2, 1e-3]:
    threshold = far_threshold(impostors, genuines, far_value=FMR)
    fmr, fnmr = fmr_fnmr(impostors, genuines, threshold)
    _, uSAR = fmr_fnmr(impostors, inverted, threshold)
    SAR = 1 - uSAR   # success attack rate of the inverted templates
    TMR = 1 - fnmr   # true match rate of genuine comparisons
    print(f'FMR: {FMR} \t threshold: {threshold} \t TMR: {TMR}, SAR: {SAR}')
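# --------------------------------------------------------------------------
# Minimal sanity check of fmr_fnmr on hand-made scores (illustrative values
# only; they are not taken from the pipeline outputs).
toy_impostors = [0.10, 0.20, 0.35, 0.60]   # scores of non-matching comparisons
toy_genuines = [0.55, 0.70, 0.80, 0.90]    # scores of matching comparisons
toy_fmr, toy_fnmr = fmr_fnmr(toy_impostors, toy_genuines, threshold=0.50)
# one impostor score (0.60) is >= 0.50 -> FMR = 1/4; no genuine score is < 0.50 -> FNMR = 0
assert abs(toy_fmr - 0.25) < 1e-9 and toy_fnmr == 0.0
# --------------------------------------------------------------------------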
import argparse
parser = argparse.ArgumentParser(description='Vulnerability evaluation of face recognition system against template inversion attack')
parser.add_argument('--FR_system', metavar='<FR_system>', type=str, default='ArcFace',
                    help='ArcFace/ElasticFace (FR system from whose database the templates are leaked)')
parser.add_argument('--FR_target', metavar='<FR_target>', type=str, default='ArcFace',
                    help='ArcFace/ElasticFace')
parser.add_argument('--dataset', metavar='<dataset>', type=str, default='MOBIO',
                    help='MOBIO/LFW')
parser.add_argument('--attack', metavar='<attack_method>', type=str, default='GaFaR',
                    help='GaFaR/GaFaR_CO/GaFaR_GS')
parser.add_argument('--checkpoint', metavar='<checkpoint>', type=str, default='./training_files/models/new_mapping_15.pth',
                    help='checkpoint of the new mapping network')
parser.add_argument('--path_eg3d_repo', metavar='<path_eg3d_repo>', type=str, default='./eg3d',
                    help='path to the EG3D repository (./eg3d)')
parser.add_argument('--path_eg3d_checkpoint', metavar='<path_eg3d_checkpoint>', type=str, default='./ffhq512-128.pkl',
                    help='path to the EG3D checkpoint (./ffhq512-128.pkl)')
args = parser.parse_args()
import os, sys
# ================== dataset ======================
if args.dataset == 'MOBIO':
    from bob.bio.face.database import MobioDatabase
    protocol = "mobile0-male-female"
    database = MobioDatabase(protocol=protocol)
elif args.dataset == 'LFW':
    from bob.bio.face.config.database.lfw_view2 import database
else:
    raise ValueError(f"[eval pipeline] {args.dataset} dataset is not defined!")
# ================== Transformers ==================
if args.FR_system == "ArcFace":
    from bob.bio.face.embeddings.pytorch import iresnet100 as get_pipeline_database
elif args.FR_system == "ElasticFace":
    from bob.bio.face.embeddings.pytorch import iresnet100_elastic as get_pipeline_database
elif args.FR_system == 'AttentionNet92':
    from bob.bio.facexzoo.transformers.pytorch import AttentionNet92 as get_pipeline_database
elif args.FR_system == 'HRNet':
    from bob.bio.facexzoo.transformers.pytorch import HRNet as get_pipeline_database
elif args.FR_system == 'RepVGG_B1':
    from bob.bio.facexzoo.transformers.pytorch import RepVGG_B1 as get_pipeline_database
elif args.FR_system == 'SwinTransformer_S':
    from bob.bio.facexzoo.transformers.pytorch import SwinTransformer_S as get_pipeline_database
else:
    raise ValueError(f"[eval pipeline] {args.FR_system} is not defined!")

# Pipeline of the FR system whose database is attacked (the templates are leaked from it)
pipeline = get_pipeline_database(
    database.annotation_type,
    fixed_positions=database.fixed_positions,
    memory_demanding=database.memory_demanding,
)
FR_transformer_database = pipeline.transformer

if args.FR_target == "ArcFace":
    from bob.bio.face.embeddings.pytorch import iresnet100 as get_pipeline_target
elif args.FR_target == "ElasticFace":
    from bob.bio.face.embeddings.pytorch import iresnet100_elastic as get_pipeline_target
elif args.FR_target == 'AttentionNet92':
    from bob.bio.facexzoo.transformers.pytorch import AttentionNet92 as get_pipeline_target
elif args.FR_target == 'HRNet':
    from bob.bio.facexzoo.transformers.pytorch import HRNet as get_pipeline_target
elif args.FR_target == 'RepVGG_B1':
    from bob.bio.facexzoo.transformers.pytorch import RepVGG_B1 as get_pipeline_target
elif args.FR_target == 'SwinTransformer_S':
    from bob.bio.facexzoo.transformers.pytorch import SwinTransformer_S as get_pipeline_target
else:
    raise ValueError(f"[eval pipeline] {args.FR_target} is not defined!")

# Pipeline of the target FR system, against which the reconstructed faces are evaluated
pipeline = get_pipeline_target(
    database.annotation_type,
    fixed_positions=database.fixed_positions,
    memory_demanding=database.memory_demanding,
)
FR_transformer_target = pipeline.transformer
# ================== Inversion Transformer ===========
from bob.pipelines import wrap, CheckpointWrapper, SampleWrapper
from bob.bio.invert.wrappers import get_invert_pipeline

sys.path.append(os.getcwd())          # import src
sys.path.append(args.path_eg3d_repo)  # import eg3d files

if args.attack == 'GaFaR':
    from transformers import GaFaR_InversionTransformer as InversionTransformer
    inv_transformer = InversionTransformer(checkpoint=args.checkpoint, eg3d_checkpoint=args.path_eg3d_checkpoint)
elif args.attack == 'GaFaR_CO':
    sys.path.append('./InsightFace-PyTorch')  # import detect_align
    from transformers import GaFaR_CO_InversionTransformer as InversionTransformer
    inv_transformer = InversionTransformer(checkpoint=args.checkpoint, eg3d_checkpoint=args.path_eg3d_checkpoint, FR_system=args.FR_system)
elif args.attack == 'GaFaR_GS':
    sys.path.append('./InsightFace-PyTorch')  # import detect_align
    from transformers import GaFaR_GS_InversionTransformer as InversionTransformer
    inv_transformer = InversionTransformer(checkpoint=args.checkpoint, eg3d_checkpoint=args.path_eg3d_checkpoint, FR_system=args.FR_system)
else:
    raise ValueError(f"[eval pipeline] {args.attack} is not defined!")

# The feature extractor is the last element of the pipeline
feature_extractor_target = FR_transformer_target[-1]
inversionAttack_transformer = get_invert_pipeline(
    FR_transformer_database, inv_transformer, feature_extractor_target
)
# ================== pipeline ======================
from bob.bio.invert.invertibility_pipeline import InvertBiometricsPipeline
from bob.bio.base.algorithm.distance import Distance

algorithm = Distance()
invert_pipeline = InvertBiometricsPipeline(
    FR_transformer_target, inversionAttack_transformer, algorithm
)

dask_client = "single-threaded"

from bob.bio.invert.pipeline import execute_inverted_simple_biometrics
# Writes the score files (e.g. results/scores-dev.csv and results/scores_inversion-dev.csv)
# that are consumed by the score-evaluation script above.
execute_inverted_simple_biometrics(
    pipeline=invert_pipeline,
    database=database,
    dask_client=dask_client,
    groups=["dev"],
    output="./results/",
    write_metadata_scores=True,
    checkpoint=True,
    dask_partition_size=200,
    dask_n_workers=0,
)
train.py 0 → 100644
"""
Training code for GaFaR (Geometry-aware Face Reconstruction)
Papers:
[TPAMI] Hatef Otroshi Shahreza and Sébastien Marcel, "Comprehensive Vulnerability Evaluation of Face Recognition Systems
to Template Inversion Attacks Via 3D Face Reconstruction", IEEE Transactions on Pattern Analysis and Machine
Intelligence, 2023.
[ICCV] Hatef Otroshi Shahreza and Sébastien Marcel, "Template Inversion Attack against Face Recognition Systems using 3D
Face Reconstruction", IEEE/CVF International Conference on Computer Vision (ICCV), 2023.
"""
import argparse
parser = argparse.ArgumentParser(description='Train face reconstruction network - GaFaR')
parser.add_argument('--path_eg3d_repo', metavar='<path_eg3d_repo>', type=str, default='./eg3d',
                    help='path to the EG3D repository (./eg3d)')
parser.add_argument('--path_eg3d_checkpoint', metavar='<path_eg3d_checkpoint>', type=str, default='./ffhq512-128.pkl',
                    help='path to the EG3D checkpoint (./ffhq512-128.pkl)')
parser.add_argument('--path_ffhq_dataset', metavar='<path_ffhq_dataset>', type=str, default='./Flickr-Faces-HQ/images1024x1024',
                    help='FFHQ directory')
parser.add_argument('--FR_system', metavar='<FR_system>', type=str, default='ArcFace',
                    help='ArcFace/ElasticFace (FR system from whose database the templates are leaked)')
parser.add_argument('--FR_loss', metavar='<FR_loss>', type=str, default='ArcFace',
                    help='ArcFace/ElasticFace (same model as FR_system in whitebox attacks and a different proxy model in blackbox attacks)')
args = parser.parse_args()
import os,sys
sys.path.append(os.getcwd()) # import src
sys.path.append(args.path_eg3d_repo) # import eg3d files
from camera_utils import LookAtPoseSampler, FOV_to_intrinsics
import pickle
import torch
import torch_utils
import random
import numpy as np
import cv2
from tqdm import tqdm
seed=0
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("************ NOTE: The torch device is:", device)
#=================== import Network =====================
path_EG3D = args.path_eg3d_checkpoint
with open(path_EG3D, 'rb') as f:
    EG3D = pickle.load(f)['G_ema']
EG3D.to(device)
EG3D.eval()
EG3D_synthesis = EG3D.synthesis
EG3D_mapping = EG3D.mapping

from src.Network import Discriminator, MappingNetwork
model_Discriminator = Discriminator()
model_Discriminator.to(device)

new_mapping = MappingNetwork(z_dim=16,    # Input latent (Z) dimensionality.
                             c_dim=512,   # Conditioning label (C) dimensionality, 0 = no labels.
                             w_dim=512,   # Intermediate latent (W) dimensionality.
                             num_ws=14,   # Number of intermediate latents to output.
                             num_layers=2 # Number of mapping layers.
                             )
new_mapping.to(device)

z_dim_new_mapping = new_mapping.z_dim
z_dim_EG3D = 512  # latent (Z) dimensionality of the EG3D generator (EG3D.z_dim)
#========================================================
#=================== import Dataset ======================
from src.Dataset import MyDataset
from torch.utils.data import DataLoader
training_dataset = MyDataset(FR_system=args.FR_system, train=True, device=device)
testing_dataset = MyDataset(FR_system=args.FR_system, train=False, device=device)
train_dataloader = training_dataset  # the training set is batched manually via get_batch()
test_dataloader = DataLoader(testing_dataset, batch_size=18, shuffle=False)
#========================================================
#=================== Optimizers =========================
# ***** optimizer_Generator
for param in new_mapping.parameters():
    param.requires_grad = True
# Three optimizers for the mapping network, one per training objective
# (teacher forcing with EG3D samples, GAN loss, FFHQ reconstruction loss).
optimizer1_Generator = torch.optim.Adam(new_mapping.parameters(), lr=1e-1)
scheduler1_Generator = torch.optim.lr_scheduler.StepLR(optimizer1_Generator, step_size=3, gamma=0.5)
optimizer2_Generator = torch.optim.Adam(new_mapping.parameters(), lr=1e-1)
scheduler2_Generator = torch.optim.lr_scheduler.StepLR(optimizer2_Generator, step_size=3, gamma=0.5)
optimizer3_Generator = torch.optim.Adam(new_mapping.parameters(), lr=1e-1)
scheduler3_Generator = torch.optim.lr_scheduler.StepLR(optimizer3_Generator, step_size=3, gamma=0.5)
# ***** optimizer_Discriminator
optimizer_Discriminator = torch.optim.Adam(model_Discriminator.parameters(), lr=1e-1)
scheduler_Discriminator = torch.optim.lr_scheduler.StepLR(optimizer_Discriminator, step_size=3, gamma=0.5)
#========================================================
#=================== import Loss ========================
# ***** ID_loss
from src.loss.FaceIDLoss import ID_Loss
ID_loss = ID_Loss(FR_system= args.FR_system, FR_loss= args.FR_loss, device=device)
# ***** Other losses
Pixel_loss = torch.nn.MSELoss()
w_loss = torch.nn.MSELoss()
#========================================================
#=================== Save models and logs ===============
os.makedirs('training_files', exist_ok=True)
os.makedirs('training_files/models', exist_ok=True)
os.makedirs('training_files/Reconstructed_images', exist_ok=True)
os.makedirs('training_files/logs_train', exist_ok=True)
with open('training_files/logs_train/generator.csv', 'w') as f:
    f.write("epoch,Pixel_loss_Gen,W_loss_Gen,ID_loss_Gen,total_loss\n")
with open('training_files/logs_train/log.txt', 'w') as f:
    pass  # create/clear the text log
saved_original_figures = False
#=================== Train ==============================
num_epochs=18
iterations_per_epoch_train=4500
iterations_per_test=150
batch_size = 6
FFHQ_align_mask = train_dataloader.FFHQ_align_mask.repeat(batch_size,1,1,1)
for epoch in range(num_epochs):
    print(f'epoch: {epoch}, \t learning rate: {optimizer1_Generator.param_groups[0]["lr"]}')
    torch.random.manual_seed(epoch)
    for iteration in tqdm(range(iterations_per_epoch_train)):
        # =========================================== Teacher-Force using pretrained EG3D ===========================================
        # generate images using EG3D
        fov_deg = 18.837
        cam2world_pose = LookAtPoseSampler.sample(np.pi/2, np.pi/2, torch.tensor([0, 0, 0.2], device=device), radius=2.7, device=device)
        intrinsics = FOV_to_intrinsics(fov_deg, device=device)
        z = torch.randn([batch_size, z_dim_EG3D]).to(device) # latent codes
        camera_params = torch.cat([cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1) # camera parameters
        camera_params = camera_params.repeat(batch_size, 1)
        w = EG3D_mapping(z, camera_params)
        img = EG3D_synthesis(w, camera_params)['image'] # NCHW, float32, dynamic range [-1, +1], no truncation
        # calculate embeddings of images
        embedding_db = ID_loss.get_embedding_db(img)
        embedding = ID_loss.get_embedding(img)
        # ===> now we have (embedding, w, and img)
        # Reconstruct image from embedding with same camera params
        new_mapping.train()
        fov_deg = 18.837
        cam2world_pose = LookAtPoseSampler.sample(np.pi/2, np.pi/2, torch.tensor([0, 0, 0.2], device=device), radius=2.7, device=device)
        intrinsics = FOV_to_intrinsics(fov_deg, device=device)
        camera_params = torch.cat([cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1) # camera parameters
        camera_params = camera_params.repeat(batch_size, 1)
        z = torch.randn([batch_size, z_dim_new_mapping]).to(device) # latent codes
        w_reconstructed = new_mapping(z, embedding_db)
        img_reconstructed = EG3D_synthesis(w_reconstructed, camera_params)['image'] # NCHW, float32, dynamic range [-1, +1], no truncation
        # calculate embeddings of images
        embedding_reconstructed = ID_loss.get_embedding(img_reconstructed)
        ### =============== Calculate Loss ============
        ID = ID_loss(embedding_reconstructed, embedding)
        Pixel = Pixel_loss(img_reconstructed, img)
        W = w_loss(w_reconstructed, w)
        loss_train_new_mapping = Pixel + ID + W
        # ================== backward =================
        optimizer1_Generator.zero_grad()
        loss_train_new_mapping.backward()
        optimizer1_Generator.step()
        # ===========================================================================================================================
        # =========================================== Training using FFHQ dataset ===================================================
        fov_deg = 18.837 # https://github.com/NVlabs/eg3d/blob/870300f29f4058b8c5088ca79e926762745e40b8/docs/visualizer_guide.md#fov
        cam2world_pose = LookAtPoseSampler.sample(np.pi/2, np.pi/2, torch.tensor([0, 0, 0.2], device=device), radius=2.7, device=device)
        intrinsics = FOV_to_intrinsics(fov_deg, device=device)
        camera_params = torch.cat([cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1) # camera parameters
        camera_params = camera_params.repeat(batch_size, 1)
        embedding_db, real_image, real_image_HQ = train_dataloader.get_batch(batch_idx=iteration, batch_size=batch_size)
        if iteration % 4 == 0:
            """
            ******************* GAN: Update Discriminator *******************
            """
            new_mapping.eval()
            model_Discriminator.train()
            # Generate batch of latent vectors
            z = torch.randn([batch_size, z_dim_new_mapping]).to(device) # latent codes
            w_fake = new_mapping(z=z, c=embedding_db).detach()
            noise = torch.randn(embedding_db.size(0), z_dim_EG3D, device=device)
            w_real = EG3D_mapping(z=noise, c=camera_params).detach()
            # ==================forward==================
            # the critic is trained to give a lower score to real w's and a higher score to generated (fake) w's
            output_discriminator_real = model_Discriminator(w_real)
            errD_real = output_discriminator_real
            output_discriminator_fake = model_Discriminator(w_fake)
            errD_fake = (-1) * output_discriminator_fake
            loss_GAN_Discriminator = (errD_fake + errD_real).mean()
            # ==================backward=================
            optimizer_Discriminator.zero_grad()
            loss_GAN_Discriminator.backward()
            optimizer_Discriminator.step()
            # weight clipping, as in WGAN
            for param in model_Discriminator.parameters():
                param.data.clamp_(-0.01, 0.01)
        if iteration % 2 == 0:
            new_mapping.train()
            model_Discriminator.eval()
            """
            ******************* GAN: Update Generator *******************
            """
            # Generate batch of latent vectors
            z = torch.randn([batch_size, z_dim_new_mapping]).to(device) # latent codes
            w_fake = new_mapping(z=z, c=embedding_db)
            # ==================forward==================
            output_discriminator_fake = model_Discriminator(w_fake)
            loss_GAN_Generator = output_discriminator_fake.mean()
            # ==================backward=================
            optimizer2_Generator.zero_grad()
            loss_GAN_Generator.backward()
            optimizer2_Generator.step()
        # if iteration % 2 == 0:
        new_mapping.train()
        """
        ******************* Train Generator *******************
        """
        # ==================forward==================
        z = torch.randn([batch_size, z_dim_new_mapping]).to(device) # latent codes
        w = new_mapping(z=z, c=embedding_db)
        img_reconstructed = EG3D_synthesis(w, c=camera_params)['image'] # NCHW, float32, dynamic range [-1, +1], no truncation
        # calculate embeddings of images
        embedding_reconstructed = ID_loss.get_embedding(img_reconstructed)
        embedding = ID_loss.get_embedding(real_image_HQ)
        ID = ID_loss(embedding_reconstructed, embedding)
        Pixel = Pixel_loss((torch.clamp(img_reconstructed*FFHQ_align_mask, min=-1, max=1) + 1) / 2.0, real_image_HQ*FFHQ_align_mask)
        loss_train_Generator = Pixel + ID
        # ==================backward=================
        optimizer3_Generator.zero_grad()
        loss_train_Generator.backward()  # (retain_graph=True)
        optimizer3_Generator.step()
        # ===========================================================================================================================
        # ================== log ======================
        iteration += 1
        if iteration % 200 == 0:
            with open('training_files/logs_train/log.txt', 'a') as f:
                f.write(f'epoch:{epoch+1}, \t iteration: {iteration}, \t loss_train_new_mapping:{loss_train_new_mapping.data.item()}\n')
    # ====================== Evaluation ===============
    new_mapping.eval()
    ID_loss_Gen_test = Pixel_loss_Gen_test = W_loss_Gen_test = total_loss_Gen_test = 0
    torch.random.manual_seed(1000)
    for iteration in range(iterations_per_test):
        # ==================forward==================
        with torch.no_grad():
            # generate images using EG3D
            fov_deg = 18.837
            cam2world_pose = LookAtPoseSampler.sample(np.pi/2, np.pi/2, torch.tensor([0, 0, 0.2], device=device), radius=2.7, device=device)
            intrinsics = FOV_to_intrinsics(fov_deg, device=device)
            camera_params = torch.cat([cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1) # camera parameters
            camera_params = camera_params.repeat(batch_size, 1)
            z = torch.randn([batch_size, z_dim_EG3D]).to(device) # latent codes
            w = EG3D_mapping(z, camera_params)
            img = EG3D_synthesis(w, camera_params)['image'] # NCHW, float32, dynamic range [-1, +1], no truncation
            # calculate embeddings of images
            embedding_db = ID_loss.get_embedding_db(img)
            embedding = ID_loss.get_embedding(img)
            # Reconstruct image from embedding with same camera params
            z = torch.randn([batch_size, z_dim_new_mapping]).to(device) # latent codes
            w_reconstructed = new_mapping(z, embedding_db)
            img_reconstructed = EG3D_synthesis(w_reconstructed, camera_params)['image'] # NCHW, float32, dynamic range [-1, +1], no truncation
            embedding_reconstructed = ID_loss.get_embedding(img_reconstructed)
            ID = ID_loss(embedding_reconstructed, embedding)
            # Pixel = Pixel_loss(img_reconstructed, img)
            Pixel = Pixel_loss((torch.clamp(img_reconstructed*FFHQ_align_mask, min=-1, max=1) + 1) / 2.0, img*FFHQ_align_mask)
            W = w_loss(w_reconstructed, w)
            total_loss_Generator = Pixel + ID + W
            ####
            ID_loss_Gen_test += ID.item()
            Pixel_loss_Gen_test += Pixel.item()
            W_loss_Gen_test += W.item()
            total_loss_Gen_test += total_loss_Generator.item()
    with open('training_files/logs_train/generator.csv', 'a') as f:
        f.write(f"{epoch+1}, {Pixel_loss_Gen_test/iterations_per_test}, {W_loss_Gen_test/iterations_per_test}, {ID_loss_Gen_test/iterations_per_test}, {total_loss_Gen_test/iterations_per_test}\n")
    # generate sample images using EG3D for visual inspection
    fov_deg = 18.837
    cam2world_pose = LookAtPoseSampler.sample(np.pi/2, np.pi/2, torch.tensor([0, 0, 0.2], device=device), radius=2.7, device=device)
    intrinsics = FOV_to_intrinsics(fov_deg, device=device)
    camera_params = torch.cat([cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1) # camera parameters
    camera_params = camera_params.repeat(batch_size, 1)
    z = torch.randn([batch_size, z_dim_EG3D]).to(device) # latent codes
    img = EG3D(z, camera_params)['image'] # NCHW, float32, dynamic range [-1, +1], no truncation
    # calculate embeddings of images
    embedding_db = ID_loss.get_embedding_db(img)
    # Reconstruct image from embedding with same camera params
    z = torch.randn([batch_size, z_dim_new_mapping]).to(device) # latent codes
    w = new_mapping(z=z, c=embedding_db)
    img_reconstructed = EG3D_synthesis(w, camera_params)['image'] # NCHW, float32, dynamic range [-1, +1], no truncation
    img_reconstructed = img_reconstructed.detach()
    if not saved_original_figures:
        saved_original_figures = True
        for i in range(img_reconstructed.size(0)):
            im = img[i].squeeze()
            im = (torch.clamp(im, min=-1, max=1) + 1) / 2.0
            im = im.cpu().numpy().transpose(1, 2, 0)
            im = (im * 255).astype(np.uint8)
            os.makedirs(f'training_files/Reconstructed_images/{i}', exist_ok=True)
            # reorder RGB to BGR before writing with OpenCV
            cv2.imwrite(f'training_files/Reconstructed_images/{i}/original.jpg', np.array([im[:,:,2], im[:,:,1], im[:,:,0]]).transpose(1, 2, 0))
    for i in range(img_reconstructed.size(0)):
        im = img_reconstructed[i].squeeze()
        im = (torch.clamp(im, min=-1, max=1) + 1) / 2.0
        im = im.cpu().numpy().transpose(1, 2, 0)
        im = (im * 255).astype(np.uint8)
        cv2.imwrite(f'training_files/Reconstructed_images/{i}/epoch_{epoch+1}.jpg', np.array([im[:,:,2], im[:,:,1], im[:,:,0]]).transpose(1, 2, 0))
    # *******************************************************
    # Save models
    torch.save(new_mapping.state_dict(), 'training_files/models/new_mapping_{}.pth'.format(epoch+1))
    # torch.save(model_Discriminator.state_dict(), 'training_files/models/Discriminator_{}.pth'.format(epoch+1))
    # Update schedulers
    scheduler1_Generator.step()
    scheduler2_Generator.step()
    scheduler3_Generator.step()
    scheduler_Discriminator.step()
#========================================================
import os,sys
import pickle
from src.loss.FaceIDLoss import Crop_and_resize, get_FaceRecognition_transformer
from sklearn.base import TransformerMixin, BaseEstimator
import torch
from bob.pipelines import SampleBatch, Sample, SampleSet
import numpy as np
from camera_utils import LookAtPoseSampler, FOV_to_intrinsics
class GaFaR_InversionTransformer(TransformerMixin, BaseEstimator):
    r"""
    Transforms any :math:`\mathbb{R}^n` template into an image :math:`\mathbb{R}^{h \times w \times c}`.

    Parameters
    ----------
    checkpoint: str
        Checkpoint of the image generator (mapping network)
    eg3d_checkpoint: str
        Checkpoint of the EG3D model
    generator:
        Instance of the generator network (a new MappingNetwork is built if None)
    """

    def __init__(self, checkpoint, eg3d_checkpoint, generator=None):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        with open(eg3d_checkpoint, 'rb') as f:
            EG3D = pickle.load(f)['G_ema']
        EG3D.to(self.device)
        EG3D.eval()
        EG3D_mapping = EG3D.mapping
        self.EG3D_synthesis = EG3D.synthesis

        if generator is None:
            from src.Network import MappingNetwork
            self.generator = MappingNetwork(z_dim=16,    # Input latent (Z) dimensionality.
                                            c_dim=512,   # Conditioning label (C) dimensionality, 0 = no labels.
                                            w_dim=512,   # Intermediate latent (W) dimensionality.
                                            num_ws=14,   # Number of intermediate latents to output.
                                            num_layers=2 # Number of mapping layers.
                                            )
        else:
            self.generator = generator
        self.generator.load_state_dict(
            torch.load(checkpoint, map_location=self.device)
        )
        self.generator.eval()
        self.generator.to(self.device)

        self.checkpoint = checkpoint
        self.eg3d_checkpoint = eg3d_checkpoint

        # Frontal camera used to render the reconstructed face
        fov_deg = 18.837
        cam2world_pose = LookAtPoseSampler.sample(np.pi/2, np.pi/2, torch.tensor([0, 0, 0.2], device=self.device), radius=2.7, device=self.device)
        intrinsics = FOV_to_intrinsics(fov_deg, device=self.device)
        self.camera_params = torch.cat([cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1) # camera parameters

    def _more_tags(self):
        return {"stateless": True, "requires_fit": False}

    def fit(self, X, y=None):
        return self

    def transform(self, samples):
        def _transform(data):
            data = data.flatten()
            data = np.reshape(data, (1, data.shape[0]))
            embedding = torch.Tensor(data).to(self.device)
            z = torch.randn([1, self.generator.z_dim]).to(self.device) # latent codes
            w = self.generator(z=z, c=embedding)
            reconstructed_img = self.EG3D_synthesis(w, self.camera_params)['image'] # NCHW, float32, dynamic range [-1, +1], no truncation
            reconstructed_img = torch.clamp(reconstructed_img, min=-1, max=1)
            reconstructed_img = (reconstructed_img + 1) / 2.0
            reconstructed_face = Crop_and_resize(reconstructed_img)[0]
            return reconstructed_face.cpu().detach().numpy() * 255.0

        if isinstance(samples[0], SampleSet):
            return [
                SampleSet(self.transform(sset.samples), parent=sset)
                for sset in samples
            ]
        else:
            return [
                Sample(_transform(sample.data), parent=sample) for sample in samples
            ]
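# --------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; this function is not called anywhere
# in the pipeline). It assumes a 512-dim leaked template stored as a numpy
# array and the default checkpoint paths of the evaluation script above.
# GaFaR_CO_InversionTransformer and GaFaR_GS_InversionTransformer are used the
# same way, with an additional FR_system argument.
def _gafar_usage_example():
    transformer = GaFaR_InversionTransformer(
        checkpoint='./training_files/models/new_mapping_15.pth',
        eg3d_checkpoint='./ffhq512-128.pkl',
    )
    leaked_template = np.random.randn(1, 512).astype('float32')  # stand-in for a leaked template
    [reconstructed] = transformer.transform([Sample(leaked_template)])
    return reconstructed.data  # reconstructed face crop, scaled to [0, 255]
# --------------------------------------------------------------------------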
class GaFaR_CO_InversionTransformer(TransformerMixin, BaseEstimator):
    r"""
    Transforms any :math:`\mathbb{R}^n` template into an image :math:`\mathbb{R}^{h \times w \times c}`,
    optimizing the camera rotation (CO: camera optimization) to best match the leaked template.

    Parameters
    ----------
    checkpoint: str
        Checkpoint of the image generator (mapping network)
    eg3d_checkpoint: str
        Checkpoint of the EG3D model
    FR_system: str
        Face recognition system (database)
    generator:
        Instance of the generator network (a new MappingNetwork is built if None)
    """

    def __init__(self, checkpoint, eg3d_checkpoint, FR_system, generator=None):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        with open(eg3d_checkpoint, 'rb') as f:
            EG3D = pickle.load(f)['G_ema']
        EG3D.to(self.device)
        EG3D.eval()
        EG3D_mapping = EG3D.mapping
        self.EG3D_synthesis = EG3D.synthesis

        if generator is None:
            from src.Network import MappingNetwork
            self.generator = MappingNetwork(z_dim=16,    # Input latent (Z) dimensionality.
                                            c_dim=512,   # Conditioning label (C) dimensionality, 0 = no labels.
                                            w_dim=512,   # Intermediate latent (W) dimensionality.
                                            num_ws=14,   # Number of intermediate latents to output.
                                            num_layers=2 # Number of mapping layers.
                                            )
        else:
            self.generator = generator
        self.generator.load_state_dict(
            torch.load(checkpoint, map_location=self.device)
        )
        self.generator.eval()
        self.generator.to(self.device)

        self.checkpoint = checkpoint
        self.eg3d_checkpoint = eg3d_checkpoint
        self.FR_system = FR_system

        self.fov_deg = 18.837
        cam2world_pose = LookAtPoseSampler.sample(np.pi/2, np.pi/2, torch.tensor([0, 0, 0.2], device=self.device), radius=2.7, device=self.device)
        intrinsics = FOV_to_intrinsics(self.fov_deg, device=self.device)
        self.camera_params = torch.cat([cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1) # camera parameters

        from detect_align import detectLM_align
        self.align = detectLM_align(detector_path='./InsightFace-PyTorch/retinaface/weights/mobilenet0.25_Final.pth', device=self.device)
        self.FaceRecognition_transformer = get_FaceRecognition_transformer(FR_system=FR_system, device=self.device)  # assuming the helper accepts a `device` keyword
        _ = self.FaceRecognition_transformer.transform(torch.zeros([1, 3, 112, 112]).to(self.device))  # forces _load_model() / eval()

    def _more_tags(self):
        return {"stateless": True, "requires_fit": False}

    def fit(self, X, y=None):
        return self

    def transform(self, samples):
        def _transform(data):
            data = data.flatten()
            data = np.reshape(data, (1, data.shape[0]))
            embedding = torch.Tensor(data).to(self.device)

            z = torch.randn([1, self.generator.z_dim]).to(self.device) # latent codes
            w = self.generator(z=z, c=embedding).detach()

            # Camera optimization: optimize two rotation offsets so that the rendered face matches the leaked embedding
            cam_rotation_param = torch.zeros(2, requires_grad=True, device=self.device)
            optimizer = torch.optim.Adam([cam_rotation_param], lr=1e-2)

            fov_deg = 18.837
            cam2world_pose = LookAtPoseSampler.sample(np.pi/2+cam_rotation_param[0], np.pi/2+cam_rotation_param[1], torch.tensor([0, 0, 0.2], device=self.device), radius=2.7, device=self.device)
            intrinsics = FOV_to_intrinsics(fov_deg, device=self.device)
            camera_params = torch.cat([cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1) # camera parameters
            reconstructed_img = self.EG3D_synthesis(w, camera_params)['image'] # NCHW, float32, dynamic range [-1, +1], no truncation
            reconstructed_img = torch.clamp(reconstructed_img, min=-1, max=1)
            reconstructed_img = (reconstructed_img + 1) / 2.0
            best_reconstructed_face = Crop_and_resize(reconstructed_img)[0] * 255.0
            emb = self.FaceRecognition_transformer.model((best_reconstructed_face.unsqueeze(0) - 127.5) / 128.0)
            loss = torch.nn.MSELoss()(embedding, emb)
            best_loss = loss.item()
            print('frontal score', best_loss)

            for i in range(121):
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                cam_rotation_param[0].data = torch.clamp(cam_rotation_param[0], min=-np.pi/4, max=np.pi/4)
                cam_rotation_param[1].data = torch.clamp(cam_rotation_param[1], min=-np.pi/6, max=np.pi/6)
                cam2world_pose = LookAtPoseSampler.sample(np.pi/2+cam_rotation_param[0], np.pi/2+cam_rotation_param[1], torch.tensor([0, 0, 0.2], device=self.device), radius=2.7, device=self.device)
                intrinsics = FOV_to_intrinsics(fov_deg, device=self.device)
                camera_params = torch.cat([cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1) # camera parameters
                reconstructed_img = self.EG3D_synthesis(w, camera_params)['image'] # NCHW, float32, dynamic range [-1, +1], no truncation
                reconstructed_img = (torch.clamp(reconstructed_img, min=-1, max=1) + 1) / 2.0 * 255.
                try:
                    reconstructed_img_align = self.align(reconstructed_img)
                except Exception:
                    break
                emb = self.FaceRecognition_transformer.model((reconstructed_img_align.unsqueeze(0) - 127.5) / 128.0)
                loss = torch.nn.MSELoss()(embedding, emb)
                if loss.item() < best_loss:
                    best_loss = loss.item()
                    best_reconstructed_face = reconstructed_img_align
            print(best_loss)

            return best_reconstructed_face.cpu().detach().numpy()

        if isinstance(samples[0], SampleSet):
            return [
                SampleSet(self.transform(sset.samples), parent=sset)
                for sset in samples
            ]
        else:
            return [
                Sample(_transform(sample.data), parent=sample) for sample in samples
            ]
class GaFaR_GS_InversionTransformer(TransformerMixin, BaseEstimator):
    r"""
    Transforms any :math:`\mathbb{R}^n` template into an image :math:`\mathbb{R}^{h \times w \times c}`,
    running a grid search (GS) over camera rotations to best match the leaked template.

    Parameters
    ----------
    checkpoint: str
        Checkpoint of the image generator (mapping network)
    eg3d_checkpoint: str
        Checkpoint of the EG3D model
    FR_system: str
        Face recognition system (database)
    generator:
        Instance of the generator network (a new MappingNetwork is built if None)
    """

    def __init__(self, checkpoint, eg3d_checkpoint, FR_system, generator=None):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        with open(eg3d_checkpoint, 'rb') as f:
            EG3D = pickle.load(f)['G_ema']
        EG3D.to(self.device)
        EG3D.eval()
        EG3D_mapping = EG3D.mapping
        self.EG3D_synthesis = EG3D.synthesis

        if generator is None:
            from src.Network import MappingNetwork
            self.generator = MappingNetwork(z_dim=16,    # Input latent (Z) dimensionality.
                                            c_dim=512,   # Conditioning label (C) dimensionality, 0 = no labels.
                                            w_dim=512,   # Intermediate latent (W) dimensionality.
                                            num_ws=14,   # Number of intermediate latents to output.
                                            num_layers=2 # Number of mapping layers.
                                            )
        else:
            self.generator = generator
        self.generator.load_state_dict(
            torch.load(checkpoint, map_location=self.device)
        )
        self.generator.eval()
        self.generator.to(self.device)

        self.checkpoint = checkpoint
        self.eg3d_checkpoint = eg3d_checkpoint
        self.FR_system = FR_system

        self.fov_deg = 18.837
        cam2world_pose = LookAtPoseSampler.sample(np.pi/2, np.pi/2, torch.tensor([0, 0, 0.2], device=self.device), radius=2.7, device=self.device)
        intrinsics = FOV_to_intrinsics(self.fov_deg, device=self.device)
        self.camera_params = torch.cat([cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1) # camera parameters

        from detect_align import detectLM_align
        self.align = detectLM_align(detector_path='./InsightFace-PyTorch/retinaface/weights/mobilenet0.25_Final.pth', device=self.device)
        self.FaceRecognition_transformer = get_FaceRecognition_transformer(FR_system=FR_system, device=self.device)  # assuming the helper accepts a `device` keyword
        _ = self.FaceRecognition_transformer.transform(torch.zeros([1, 3, 112, 112]).to(self.device))  # forces _load_model() / eval()

    def _more_tags(self):
        return {"stateless": True, "requires_fit": False}

    def fit(self, X, y=None):
        return self

    def transform(self, samples):
        def _transform(data):
            data = data.flatten()
            data = np.reshape(data, (1, data.shape[0]))
            embedding = torch.Tensor(data).to(self.device)

            z = torch.randn([1, self.generator.z_dim]).to(self.device) # latent codes
            w = self.generator(z=z, c=embedding)

            # Frontal rendering as the starting point of the grid search
            reconstructed_img = self.EG3D_synthesis(w, self.camera_params)['image'] # NCHW, float32, dynamic range [-1, +1], no truncation
            reconstructed_img = torch.clamp(reconstructed_img, min=-1, max=1)
            reconstructed_img = (reconstructed_img + 1) / 2.0
            best_reconstructed_face = Crop_and_resize(reconstructed_img)[0] * 255.0
            emb = self.FaceRecognition_transformer.model((best_reconstructed_face.unsqueeze(0) - 127.5) / 128.0)
            best_dissim = torch.nn.MSELoss()(embedding, emb)
            print('frontal score', best_dissim)

            # Grid search over camera rotations; keep the pose whose embedding is closest to the leaked template
            for f in np.linspace(start=-np.pi/4, stop=np.pi/4, num=11, endpoint=True):  # yaw
                for t in np.linspace(start=-np.pi/6, stop=np.pi/6, num=11, endpoint=True):
                    fov_deg = 18.837
                    cam2world_pose = LookAtPoseSampler.sample(np.pi/2+t, np.pi/2+f, torch.tensor([0, 0, 0.2], device=self.device), radius=2.7, device=self.device)
                    intrinsics = FOV_to_intrinsics(fov_deg, device=self.device)
                    camera_params = torch.cat([cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1) # camera parameters
                    reconstructed_img = self.EG3D_synthesis(w, camera_params)['image'] # NCHW, float32, dynamic range [-1, +1], no truncation
                    reconstructed_img = (torch.clamp(reconstructed_img, min=-1, max=1) + 1) / 2.0 * 255.
                    try:
                        reconstructed_img_align = self.align(reconstructed_img)
                    except Exception:
                        continue
                    emb = self.FaceRecognition_transformer.model((reconstructed_img_align.unsqueeze(0) - 127.5) / 128.0)
                    dissim = torch.nn.MSELoss()(embedding, emb)
                    if dissim < best_dissim:
                        best_dissim = dissim
                        best_reconstructed_face = reconstructed_img_align
            print(best_dissim)

            return best_reconstructed_face.cpu().detach().numpy()

        if isinstance(samples[0], SampleSet):
            return [
                SampleSet(self.transform(sset.samples), parent=sset)
                for sset in samples
            ]
        else:
            return [
                Sample(_transform(sample.data), parent=sample) for sample in samples
            ]