Project: bob / bob.paper.iccv2023_face_ti

Commit e0ffc354, authored 1 year ago by Hatef OTROSHI
+ train and eval
Parent: 173844f7
Showing 4 changed files with 939 additions and 0 deletions:

eval_SAR_TMR.py          +25  −0
evaluation_pipeline.py   +130 −0
train.py                 +394 −0
transformers.py          +390 −0
eval_SAR_TMR.py (new file, mode 100644; +25 −0)
import numpy as np
import bob.bio.base
from bob.measure import far_threshold

# Load impostor/genuine scores of the face recognition system, and the scores of the
# inverted (reconstructed) faces against the enrolled templates.
impostors, genuines = bob.bio.base.score.load.split_csv_scores('results/scores-dev.csv')
_, inverted = bob.bio.base.score.load.split_csv_scores('results/scores_inversion-dev.csv')


def fmr_fnmr(neg, pos, threshold):
    """False match rate and false non-match rate at the given decision threshold."""
    fmr = np.mean(np.where(np.array(neg) >= threshold, 1, 0))
    fnmr = np.mean(np.where(np.array(pos) < threshold, 1, 0))
    return fmr, fnmr


for FMR in [1e-2, 1e-3]:
    # Decision threshold of the FR system at the given false match rate
    threshold = far_threshold(impostors, genuines, far_value=FMR)
    fmr, fnmr = fmr_fnmr(impostors, genuines, threshold)
    # Treating inverted-template scores as "positives", the fraction rejected is uSAR;
    # the Success Attack Rate (SAR) is its complement.
    _, uSAR = fmr_fnmr(impostors, inverted, threshold)
    SAR = 1 - uSAR
    TMR = 1 - fnmr
    print(f'FMR: {FMR}\t threshold: {threshold}\t TMR: {TMR}, SAR: {SAR}')
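For intuition, here is a minimal self-contained sketch (synthetic Gaussian score distributions, hypothetical values, NumPy only) of how the threshold chosen at a fixed FMR determines TMR and SAR:

import numpy as np

rng = np.random.default_rng(0)
# Synthetic score distributions (hypothetical, for illustration only):
# impostor scores low, genuine scores high, inverted (attack) scores in between.
impostors = rng.normal(0.0, 0.1, 10_000)
genuines = rng.normal(0.8, 0.1, 10_000)
inverted = rng.normal(0.3, 0.1, 10_000)

FMR = 1e-2
# Threshold at which exactly a fraction FMR of impostor scores is accepted.
threshold = np.quantile(impostors, 1 - FMR)

TMR = np.mean(genuines >= threshold)   # genuine comparisons accepted
SAR = np.mean(inverted >= threshold)   # inverted-template attacks accepted
print(f'FMR: {FMR}\t threshold: {threshold:.3f}\t TMR: {TMR:.3f}, SAR: {SAR:.3f}')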
evaluation_pipeline.py (new file, mode 100644; +130 −0)
import argparse

parser = argparse.ArgumentParser(
    description='Vulnerability evaluation of face recognition system against template inversion attack')
parser.add_argument('--FR_system', metavar='<FR_system>', type=str, default='ArcFace',
                    help='ArcFace/ElasticFace (FR system from whose database the templates are leaked)')
parser.add_argument('--FR_target', metavar='<FR_target>', type=str, default='ArcFace',
                    help='ArcFace/ElasticFace')
parser.add_argument('--dataset', metavar='<dataset>', type=str, default='MOBIO',
                    help='MOBIO/LFW')
parser.add_argument('--attack', metavar='<attack_method>', type=str, default='GaFaR',
                    help='GaFaR/GaFaR_CO/GaFaR_GS')
parser.add_argument('--checkpoint', metavar='<checkpoint>', type=str,
                    default='./training_files/models/new_mapping_15.pth',
                    help='checkpoint of the new mapping network')
parser.add_argument('--path_eg3d_repo', metavar='<path_eg3d_repo>', type=str,
                    default='./eg3d', help='path to the eg3d repository')
parser.add_argument('--path_eg3d_checkpoint', metavar='<path_eg3d_checkpoint>', type=str,
                    default='./ffhq512-128.pkl', help='path to the EG3D checkpoint')
args = parser.parse_args()

import os, sys

# ================== dataset ======================
if args.dataset == 'MOBIO':
    from bob.bio.face.database import MobioDatabase
    protocol = "mobile0-male-female"
    database = MobioDatabase(protocol=protocol)
elif args.dataset == 'LFW':
    from bob.bio.face.config.database.lfw_view2 import database
else:
    print(f"[eval pipeline] {args.dataset} dataset is not defined!")

# ================== Transformers ==================
if args.FR_system == "ArcFace":
    from bob.bio.face.embeddings.pytorch import iresnet100 as get_pipeline_database
elif args.FR_system == "ElasticFace":
    from bob.bio.face.embeddings.pytorch import iresnet100_elastic as get_pipeline_database
elif args.FR_system == 'AttentionNet92':
    from bob.bio.facexzoo.transformers.pytorch import AttentionNet92 as get_pipeline_database
elif args.FR_system == 'HRNet':
    from bob.bio.facexzoo.transformers.pytorch import HRNet as get_pipeline_database
elif args.FR_system == 'RepVGG_B1':
    from bob.bio.facexzoo.transformers.pytorch import RepVGG_B1 as get_pipeline_database
elif args.FR_system == 'SwinTransformer_S':
    from bob.bio.facexzoo.transformers.pytorch import SwinTransformer_S as get_pipeline_database
else:
    print(f"[eval pipeline] {args.FR_system} is not defined!")

pipeline = get_pipeline_database(
    database.annotation_type,
    fixed_positions=database.fixed_positions,
    memory_demanding=database.memory_demanding,
)
FR_transformer_database = pipeline.transformer

if args.FR_target == "ArcFace":
    from bob.bio.face.embeddings.pytorch import iresnet100 as get_pipeline_target
elif args.FR_target == "ElasticFace":
    from bob.bio.face.embeddings.pytorch import iresnet100_elastic as get_pipeline_target
elif args.FR_target == 'AttentionNet92':
    from bob.bio.facexzoo.transformers.pytorch import AttentionNet92 as get_pipeline_target
elif args.FR_target == 'HRNet':
    from bob.bio.facexzoo.transformers.pytorch import HRNet as get_pipeline_target
elif args.FR_target == 'RepVGG_B1':
    from bob.bio.facexzoo.transformers.pytorch import RepVGG_B1 as get_pipeline_target
elif args.FR_target == 'SwinTransformer_S':
    from bob.bio.facexzoo.transformers.pytorch import SwinTransformer_S as get_pipeline_target
else:
    print(f"[eval pipeline] {args.FR_target} is not defined!")

pipeline = get_pipeline_target(
    database.annotation_type,
    fixed_positions=database.fixed_positions,
    memory_demanding=database.memory_demanding,
)
FR_transformer_target = pipeline.transformer

# ================== Inversion Transformer ===========
from bob.pipelines import wrap, CheckpointWrapper, SampleWrapper
from bob.bio.invert.wrappers import get_invert_pipeline

sys.path.append(os.getcwd())          # import src
sys.path.append(args.path_eg3d_repo)  # import eg3d files

if args.attack == 'GaFaR':
    from transformers import GaFaR_InversionTransformer as InversionTransformer
    inv_transformer = InversionTransformer(checkpoint=args.checkpoint,
                                           eg3d_checkpoint=args.path_eg3d_checkpoint)
elif args.attack == 'GaFaR_CO':
    sys.path.append('./InsightFace-PyTorch')  # import detect_align
    from transformers import GaFaR_CO_InversionTransformer as InversionTransformer
    inv_transformer = InversionTransformer(checkpoint=args.checkpoint,
                                           eg3d_checkpoint=args.path_eg3d_checkpoint,
                                           FR_system=args.FR_system)
elif args.attack == 'GaFaR_GS':
    sys.path.append('./InsightFace-PyTorch')  # import detect_align
    from transformers import GaFaR_GS_InversionTransformer as InversionTransformer
    inv_transformer = InversionTransformer(checkpoint=args.checkpoint,
                                           eg3d_checkpoint=args.path_eg3d_checkpoint,
                                           FR_system=args.FR_system)
else:
    print(f"[eval pipeline] {args.attack} is not defined!")
    inv_transformer = InversionTransformer(checkpoint=args.checkpoint,
                                           eg3d_checkpoint=args.path_eg3d_checkpoint)

# The feature extractor is the last element of the pipeline
feature_extractor_target = FR_transformer_target[-1]

inversionAttack_transformer = get_invert_pipeline(FR_transformer_database,
                                                  inv_transformer,
                                                  feature_extractor_target)

# ================== pipeline ======================
from bob.bio.invert.invertibility_pipeline import InvertBiometricsPipeline
from bob.bio.base.algorithm.distance import Distance

algorithm = Distance()
invert_pipeline = InvertBiometricsPipeline(FR_transformer_target,
                                           inversionAttack_transformer,
                                           algorithm)

dask_client = "single-threaded"

from bob.bio.invert.pipeline import execute_inverted_simple_biometrics
execute_inverted_simple_biometrics(
    pipeline=invert_pipeline,
    database=database,
    dask_client=dask_client,
    groups=["dev"],
    output="./results/",
    write_metadata_scores=True,
    checkpoint=True,
    dask_partition_size=200,
    dask_n_workers=0,
)
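Conceptually, the attack pipeline assembled above composes three stages. The sketch below uses hypothetical stand-in functions (not the actual bob transformers) purely to show the data flow:

import numpy as np

# Hypothetical stand-ins for the three wrapped stages (illustration only):
extract_template_database = lambda img: np.random.randn(512)   # FR_transformer_database
invert_template = lambda t: np.random.rand(3, 112, 112)        # inv_transformer (GaFaR)
extract_template_target = lambda img: np.random.randn(512)     # feature_extractor_target

def inversion_attack_pipeline(face_image):
    """Schematic data flow of the pipeline assembled above (names hypothetical)."""
    leaked_template = extract_template_database(face_image)        # template leaked from the database
    reconstructed_face = invert_template(leaked_template)          # GaFaR face reconstruction
    probe_template = extract_template_target(reconstructed_face)   # template of the target system
    return probe_template  # scored against enrolled templates by the Distance algorithm

probe = inversion_attack_pipeline(np.random.rand(3, 112, 112))
print(probe.shape)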
train.py (new file, mode 100644; +394 −0)
"""
Training code for GaFaR (Geometry-aware Face Reconstruction)
Papers:
[TPAMI] Hatef Otroshi Shahreza and Sébastien Marcel,
"
Comprehensive Vulnerability Evaluation of Face Recognition Systems
to Template Inversion Attacks Via 3D Face Reconstruction
"
, IEEE Transactions on Pattern Analysis and Machine
Intelligence, 2023.
[ICCV] Hatef Otroshi Shahreza and Sébastien Marcel,
"
Template Inversion Attack against Face Recognition Systems using 3D
Face Reconstruction
"
, IEEE/CVF International Conference on Computer Vision (ICCV), 2023.
"""
import
argparse
parser
=
argparse
.
ArgumentParser
(
description
=
'
Train face reconstruction network - GaFaR
'
)
parser
.
add_argument
(
'
--path_eg3d_repo
'
,
metavar
=
'
<path_eg3d_repo>
'
,
type
=
str
,
default
=
'
./eg3d
'
,
help
=
'
./eg3d
'
)
parser
.
add_argument
(
'
--path_eg3d_checkpoint
'
,
metavar
=
'
<path_eg3d_checkpoint>
'
,
type
=
str
,
default
=
'
./ffhq512-128.pkl
'
,
help
=
'
./ffhq512-128.pkl`
'
)
parser
.
add_argument
(
'
--path_ffhq_dataset
'
,
metavar
=
'
<path_ffhq_dataset>
'
,
type
=
str
,
default
=
'
./Flickr-Faces-HQ/images1024x1024
'
,
help
=
'
FFHQ directory`
'
)
parser
.
add_argument
(
'
--FR_system
'
,
metavar
=
'
<FR_system>
'
,
type
=
str
,
default
=
'
ArcFace
'
,
help
=
'
ArcFace/ElasticFace (FR system from whose database the templates are leaked)
'
)
parser
.
add_argument
(
'
--FR_loss
'
,
metavar
=
'
<FR_loss>
'
,
type
=
str
,
default
=
'
ArcFace
'
,
help
=
'
ArcFace/ElasticFace (same model as FR_loss in whitebox and a different proxy model in blackbox attacks)
'
)
args
=
parser
.
parse_args
()
import
os
,
sys
sys
.
path
.
append
(
os
.
getcwd
())
# import src
sys
.
path
.
append
(
args
.
path_eg3d_repo
)
# import eg3d files
from
camera_utils
import
LookAtPoseSampler
,
FOV_to_intrinsics
import
pickle
import
torch
import
torch_utils
import
random
import
numpy
as
np
import
cv2
from
tqdm
import
tqdm
seed
=
0
random
.
seed
(
seed
)
np
.
random
.
seed
(
seed
)
torch
.
manual_seed
(
seed
)
torch
.
backends
.
cudnn
.
deterministic
=
True
torch
.
backends
.
cudnn
.
benchmark
=
False
device
=
torch
.
device
(
'
cuda
'
if
torch
.
cuda
.
is_available
()
else
'
cpu
'
)
print
(
"
************ NOTE: The torch device is:
"
,
device
)
#=================== import Network =====================
path_EG3D
=
args
.
path_eg3d_checkpoint
with
open
(
path_EG3D
,
'
rb
'
)
as
f
:
EG3D
=
pickle
.
load
(
f
)[
'
G_ema
'
]
EG3D
.
to
(
device
)
EG3D
.
eval
()
EG3D_synthesis
=
EG3D
.
synthesis
EG3D_mapping
=
EG3D
.
mapping
from
src.Network
import
Discriminator
,
MappingNetwork
model_Discriminator
=
Discriminator
()
model_Discriminator
.
to
(
device
)
new_mapping
=
MappingNetwork
(
z_dim
=
16
,
# Input latent (Z) dimensionality.
c_dim
=
512
,
# Conditioning label (C) dimensionality, 0 = no labels.
w_dim
=
512
,
# Intermediate latent (W) dimensionality.
num_ws
=
14
,
# Number of intermediate latents to output.
num_layers
=
2
,
# Number of mapping layers.
)
new_mapping
.
to
(
device
)
z_dim_new_mapping
=
new_mapping
.
z_dim
z_dim_EG3D
=
EG3D
.
z_dim
z_dim_EG3D
=
512
#========================================================
#=================== import Dataset ======================
from
src.Dataset
import
MyDataset
from
torch.utils.data
import
DataLoader
training_dataset
=
MyDataset
(
FR_system
=
args
.
FR_system
,
train
=
True
,
device
=
device
)
testing_dataset
=
MyDataset
(
FR_system
=
args
.
FR_system
,
train
=
False
,
device
=
device
)
train_dataloader
=
training_dataset
test_dataloader
=
DataLoader
(
testing_dataset
,
batch_size
=
18
,
shuffle
=
False
)
#========================================================
#=================== Optimizers =========================
# ***** optimizer_Generator
for
param
in
new_mapping
.
parameters
():
param
.
requires_grad
=
True
# ***** optimizer_Generator
optimizer1_Generator
=
torch
.
optim
.
Adam
(
new_mapping
.
parameters
(),
lr
=
1e-1
)
scheduler1_Generator
=
torch
.
optim
.
lr_scheduler
.
StepLR
(
optimizer1_Generator
,
step_size
=
3
,
gamma
=
0.5
)
optimizer2_Generator
=
torch
.
optim
.
Adam
(
new_mapping
.
parameters
(),
lr
=
1e-1
)
scheduler2_Generator
=
torch
.
optim
.
lr_scheduler
.
StepLR
(
optimizer2_Generator
,
step_size
=
3
,
gamma
=
0.5
)
optimizer3_Generator
=
torch
.
optim
.
Adam
(
new_mapping
.
parameters
(),
lr
=
1e-1
)
scheduler3_Generator
=
torch
.
optim
.
lr_scheduler
.
StepLR
(
optimizer3_Generator
,
step_size
=
3
,
gamma
=
0.5
)
# ***** optimizer_Discriminator
optimizer_Discriminator
=
torch
.
optim
.
Adam
(
model_Discriminator
.
parameters
(),
lr
=
1e-1
)
scheduler_Discriminator
=
torch
.
optim
.
lr_scheduler
.
StepLR
(
optimizer_Discriminator
,
step_size
=
3
,
gamma
=
0.5
)
#========================================================
#=================== import Loss ========================
# ***** ID_loss
from
src.loss.FaceIDLoss
import
ID_Loss
ID_loss
=
ID_Loss
(
FR_system
=
args
.
FR_system
,
FR_loss
=
args
.
FR_loss
,
device
=
device
)
# ***** Other losses
Pixel_loss
=
torch
.
nn
.
MSELoss
()
w_loss
=
torch
.
nn
.
MSELoss
()
#========================================================
#=================== Save models and logs ===============
import
os
os
.
makedirs
(
'
training_files
'
,
exist_ok
=
True
)
os
.
makedirs
(
'
training_files/models
'
,
exist_ok
=
True
)
os
.
makedirs
(
'
training_files/Reconstructed_images
'
,
exist_ok
=
True
)
os
.
makedirs
(
'
training_files/logs_train
'
,
exist_ok
=
True
)
with
open
(
'
training_files/logs_train/generator.csv
'
,
'
w
'
)
as
f
:
f
.
write
(
"
epoch,Pixel_loss_Gen,W_loss_Gen,ID_loss_Gen,total_loss
\n
"
)
with
open
(
'
training_files/logs_train/log.txt
'
,
'
w
'
)
as
f
:
pass
saved_original_figures
=
False
#=================== Train ==============================
num_epochs
=
18
iterations_per_epoch_train
=
4500
iterations_per_test
=
150
batch_size
=
6
FFHQ_align_mask
=
train_dataloader
.
FFHQ_align_mask
.
repeat
(
batch_size
,
1
,
1
,
1
)
for
epoch
in
range
(
num_epochs
):
print
(
f
'
epoch:
{
epoch
}
,
\t
learning rate:
{
optimizer1_Generator
.
param_groups
[
0
][
"
lr
"
]
}
'
)
torch
.
random
.
manual_seed
(
epoch
)
for
iteration
in
tqdm
(
range
(
iterations_per_epoch_train
)):
# =========================================== Teacher-Force using pretrained EG3D ===========================================
# generate images using EG3D
fov_deg
=
18.837
cam2world_pose
=
LookAtPoseSampler
.
sample
(
np
.
pi
/
2
,
np
.
pi
/
2
,
torch
.
tensor
([
0
,
0
,
0.2
],
device
=
device
),
radius
=
2.7
,
device
=
device
)
intrinsics
=
FOV_to_intrinsics
(
fov_deg
,
device
=
device
)
z
=
torch
.
randn
([
batch_size
,
z_dim_EG3D
]).
to
(
device
)
# latent codes
camera_params
=
torch
.
cat
([
cam2world_pose
.
reshape
(
-
1
,
16
),
intrinsics
.
reshape
(
-
1
,
9
)],
1
)
# camera parameters
camera_params
=
camera_params
.
repeat
(
batch_size
,
1
)
w
=
EG3D_mapping
(
z
,
camera_params
)
img
=
EG3D_synthesis
(
w
,
camera_params
)[
'
image
'
]
# NCHW, float32, dynamic range [-1, +1], no truncation
# calculate embeddings of images
embedding_db
=
ID_loss
.
get_embedding_db
(
img
)
embedding
=
ID_loss
.
get_embedding
(
img
)
# ===> now we have (embedding, w, and img)
# Reconstruct image from embedding with same camera params
new_mapping
.
train
()
fov_deg
=
18.837
cam2world_pose
=
LookAtPoseSampler
.
sample
(
np
.
pi
/
2
,
np
.
pi
/
2
,
torch
.
tensor
([
0
,
0
,
0.2
],
device
=
device
),
radius
=
2.7
,
device
=
device
)
intrinsics
=
FOV_to_intrinsics
(
fov_deg
,
device
=
device
)
camera_params
=
torch
.
cat
([
cam2world_pose
.
reshape
(
-
1
,
16
),
intrinsics
.
reshape
(
-
1
,
9
)],
1
)
# camera parameters
camera_params
=
camera_params
.
repeat
(
batch_size
,
1
)
z
=
torch
.
randn
([
batch_size
,
z_dim_new_mapping
]).
to
(
device
)
# latent codes
w_reconstructed
=
new_mapping
(
z
,
embedding_db
)
img_reconstructed
=
EG3D_synthesis
(
w_reconstructed
,
camera_params
)[
'
image
'
]
# NCHW, float32, dynamic range [-1, +1], no truncation
# calculate embeddings of images
embedding_reconstructed
=
ID_loss
.
get_embedding
(
img_reconstructed
)
### =============== Calculate Loss ============
ID
=
ID_loss
(
embedding_reconstructed
,
embedding
)
Pixel
=
Pixel_loss
(
img_reconstructed
,
img
)
W
=
w_loss
(
w_reconstructed
,
w
)
loss_train_new_mapping
=
Pixel
+
ID
+
W
# ================== backward =================
optimizer1_Generator
.
zero_grad
()
loss_train_new_mapping
.
backward
()
optimizer1_Generator
.
step
()
# ===========================================================================================================================
# =========================================== Trainin using FFHQ dataset ====================================================
#
fov_deg
=
18.837
# https://github.com/NVlabs/eg3d/blob/870300f29f4058b8c5088ca79e926762745e40b8/docs/visualizer_guide.md#fov
cam2world_pose
=
LookAtPoseSampler
.
sample
(
np
.
pi
/
2
,
np
.
pi
/
2
,
torch
.
tensor
([
0
,
0
,
0.2
],
device
=
device
),
radius
=
2.7
,
device
=
device
)
intrinsics
=
FOV_to_intrinsics
(
fov_deg
,
device
=
device
)
camera_params
=
torch
.
cat
([
cam2world_pose
.
reshape
(
-
1
,
16
),
intrinsics
.
reshape
(
-
1
,
9
)],
1
)
# camera parameters
camera_params
=
camera_params
.
repeat
(
batch_size
,
1
)
embedding_db
,
real_image
,
real_image_HQ
=
train_dataloader
.
get_batch
(
batch_idx
=
iteration
,
batch_size
=
batch_size
)
if
iteration
%
4
==
0
:
"""
******************* GAN: Update Discriminator *******************
"""
new_mapping
.
eval
()
model_Discriminator
.
train
()
# Generate batch of latent vectors
z
=
torch
.
randn
([
batch_size
,
z_dim_new_mapping
]).
to
(
device
)
# latent codes
w_fake
=
new_mapping
(
z
=
z
,
c
=
embedding_db
).
detach
()
noise
=
torch
.
randn
(
embedding_db
.
size
(
0
),
z_dim_EG3D
,
device
=
device
)
w_real
=
EG3D_mapping
(
z
=
noise
,
c
=
camera_params
).
detach
()
# ==================forward==================
# disc should give lower score for real and high for gnerated (fake)
output_discriminator_real
=
model_Discriminator
(
w_real
)
errD_real
=
output_discriminator_real
output_discriminator_fake
=
model_Discriminator
(
w_fake
)
errD_fake
=
(
-
1
)
*
output_discriminator_fake
loss_GAN_Discriminator
=
(
errD_fake
+
errD_real
).
mean
()
# ==================backward=================
optimizer_Discriminator
.
zero_grad
()
loss_GAN_Discriminator
.
backward
()
optimizer_Discriminator
.
step
()
for
param
in
model_Discriminator
.
parameters
():
param
.
data
.
clamp_
(
-
0.01
,
0.01
)
if
iteration
%
2
==
0
:
new_mapping
.
train
()
model_Discriminator
.
eval
()
"""
******************* GAN: Update Generator *******************
"""
# Generate batch of latent vectors
z
=
torch
.
randn
([
batch_size
,
z_dim_new_mapping
]).
to
(
device
)
# latent codes
w_fake
=
new_mapping
(
z
=
z
,
c
=
embedding_db
)
# ==================forward==================
output_discriminator_fake
=
model_Discriminator
(
w_fake
)
loss_GAN_Generator
=
output_discriminator_fake
.
mean
()
# ==================backward=================
optimizer2_Generator
.
zero_grad
()
loss_GAN_Generator
.
backward
()
optimizer2_Generator
.
step
()
# if iteration % 2 == 0:
new_mapping
.
train
()
"""
******************* Train Generator *******************
"""
# ==================forward==================
z
=
torch
.
randn
([
batch_size
,
z_dim_new_mapping
]).
to
(
device
)
# latent codes
w
=
new_mapping
(
z
=
z
,
c
=
embedding_db
)
img_reconstructed
=
EG3D_synthesis
(
w
,
c
=
camera_params
)[
'
image
'
]
# NCHW, float32, dynamic range [-1, +1], no truncation
# calculate embeddings of images
embedding_reconstructed
=
ID_loss
.
get_embedding
(
img_reconstructed
)
embedding
=
ID_loss
.
get_embedding
(
real_image_HQ
)
ID
=
ID_loss
(
embedding_reconstructed
,
embedding
)
Pixel
=
Pixel_loss
(
(
torch
.
clamp
(
img_reconstructed
*
FFHQ_align_mask
,
min
=-
1
,
max
=
1
)
+
1
)
/
2.0
,
real_image_HQ
*
FFHQ_align_mask
)
loss_train_Generator
=
Pixel
+
ID
# ==================backward=================
optimizer3_Generator
.
zero_grad
()
loss_train_Generator
.
backward
()
#(retain_graph=True)
optimizer3_Generator
.
step
()
# ===========================================================================================================================
# ================== log ======================
iteration
+=
1
if
iteration
%
200
==
0
:
with
open
(
'
training_files/logs_train/log.txt
'
,
'
a
'
)
as
f
:
f
.
write
(
f
'
epoch:
{
epoch
+
1
}
,
\t
iteration:
{
iteration
}
,
\t
loss_train_new_mapping:
{
loss_train_new_mapping
.
data
.
item
()
}
\n
'
)
pass
# ====================== Evaluation ===============
new_mapping
.
eval
()
ID_loss_Gen_test
=
Pixel_loss_Gen_test
=
W_loss_Gen_test
=
total_loss_Gen_test
=
0
torch
.
random
.
manual_seed
(
1000
)
for
iteration
in
range
(
iterations_per_test
):
# ==================forward==================
with
torch
.
no_grad
():
# generate images using EG3D
fov_deg
=
18.837
cam2world_pose
=
LookAtPoseSampler
.
sample
(
np
.
pi
/
2
,
np
.
pi
/
2
,
torch
.
tensor
([
0
,
0
,
0.2
],
device
=
device
),
radius
=
2.7
,
device
=
device
)
intrinsics
=
FOV_to_intrinsics
(
fov_deg
,
device
=
device
)
camera_params
=
torch
.
cat
([
cam2world_pose
.
reshape
(
-
1
,
16
),
intrinsics
.
reshape
(
-
1
,
9
)],
1
)
# camera parameters
camera_params
=
camera_params
.
repeat
(
batch_size
,
1
)
z
=
torch
.
randn
([
batch_size
,
z_dim_EG3D
]).
to
(
device
)
# latent codes
w
=
EG3D_mapping
(
z
,
camera_params
)
img
=
EG3D_synthesis
(
w
,
camera_params
)[
'
image
'
]
# NCHW, float32, dynamic range [-1, +1], no truncation
# calculate embeddings of images
embedding_db
=
ID_loss
.
get_embedding_db
(
img
)
embedding
=
ID_loss
.
get_embedding
(
img
)
# Reconstruct image from embedding with same camera params
z
=
torch
.
randn
([
batch_size
,
z_dim_new_mapping
]).
to
(
device
)
# latent codes
w_reconstructed
=
new_mapping
(
z
,
embedding_db
)
img_reconstructed
=
EG3D_synthesis
(
w_reconstructed
,
camera_params
)[
'
image
'
]
# NCHW, float32, dynamic range [-1, +1], no truncation
embedding_reconstructed
=
ID_loss
.
get_embedding
(
img_reconstructed
)
ID
=
ID_loss
(
embedding_reconstructed
,
embedding
)
# Pixel = Pixel_loss(img_reconstructed, img)
Pixel
=
Pixel_loss
(
(
torch
.
clamp
(
img_reconstructed
*
FFHQ_align_mask
,
min
=-
1
,
max
=
1
)
+
1
)
/
2.0
,
img
*
FFHQ_align_mask
)
W
=
w_loss
(
w_reconstructed
,
w
)
total_loss_Generator
=
Pixel
+
ID
+
W
####
ID_loss_Gen_test
+=
ID
.
item
()
Pixel_loss_Gen_test
+=
Pixel
.
item
()
W_loss_Gen_test
+=
W
.
item
()
total_loss_Gen_test
+=
total_loss_Generator
.
item
()
with
open
(
'
training_files/logs_train/generator.csv
'
,
'
a
'
)
as
f
:
f
.
write
(
f
"
{
epoch
+
1
}
,
{
Pixel_loss_Gen_test
/
iteration
}
,
{
W_loss_Gen_test
/
iteration
}
,
{
ID_loss_Gen_test
/
iteration
}
,
{
total_loss_Gen_test
/
iteration
}
\n
"
)
# generate images using EG3D
fov_deg
=
18.837
cam2world_pose
=
LookAtPoseSampler
.
sample
(
np
.
pi
/
2
,
np
.
pi
/
2
,
torch
.
tensor
([
0
,
0
,
0.2
],
device
=
device
),
radius
=
2.7
,
device
=
device
)
intrinsics
=
FOV_to_intrinsics
(
fov_deg
,
device
=
device
)
camera_params
=
torch
.
cat
([
cam2world_pose
.
reshape
(
-
1
,
16
),
intrinsics
.
reshape
(
-
1
,
9
)],
1
)
# camera parameters
camera_params
=
camera_params
.
repeat
(
batch_size
,
1
)
z
=
torch
.
randn
([
batch_size
,
z_dim_EG3D
]).
to
(
device
)
# latent codes
img
=
EG3D
(
z
,
camera_params
)[
'
image
'
]
# NCHW, float32, dynamic range [-1, +1], no truncation
# calculate embeddings of images
embedding_db
=
ID_loss
.
get_embedding_db
(
img
)
# Reconstruct image from embedding with same camera params
z
=
torch
.
randn
([
batch_size
,
z_dim_new_mapping
]).
to
(
device
)
# latent codes
w
=
new_mapping
(
z
=
z
,
c
=
embedding_db
)
img_reconstructed
=
EG3D_synthesis
(
w
,
camera_params
)[
'
image
'
]
# NCHW, float32, dynamic range [-1, +1], no truncation
img_reconstructed
=
img_reconstructed
.
detach
()
if
not
saved_original_figures
:
saved_original_figures
=
True
for
i
in
range
(
img_reconstructed
.
size
(
0
)):
im
=
img
[
i
].
squeeze
()
im
=
(
torch
.
clamp
(
im
,
min
=-
1
,
max
=
1
)
+
1
)
/
2.0
im
=
(
im
.
cpu
().
numpy
().
transpose
(
1
,
2
,
0
))
im
=
(
im
*
255
).
astype
(
int
)
os
.
makedirs
(
f
'
training_files/Reconstructed_images/
{
i
}
'
,
exist_ok
=
True
)
cv2
.
imwrite
(
f
'
training_files/Reconstructed_images/
{
i
}
/original.jpg
'
,
np
.
array
([
im
[:,:,
2
],
im
[:,:,
1
],
im
[:,:,
0
]]).
transpose
(
1
,
2
,
0
))
for
i
in
range
(
img_reconstructed
.
size
(
0
)):
img
=
img_reconstructed
[
i
].
squeeze
()
img
=
(
torch
.
clamp
(
img
,
min
=-
1
,
max
=
1
)
+
1
)
/
2.0
im
=
(
img
.
cpu
().
numpy
().
transpose
(
1
,
2
,
0
))
im
=
(
im
*
255
).
astype
(
int
)
cv2
.
imwrite
(
f
'
training_files/Reconstructed_images/
{
i
}
/epoch_
{
epoch
+
1
}
.jpg
'
,
np
.
array
([
im
[:,:,
2
],
im
[:,:,
1
],
im
[:,:,
0
]]).
transpose
(
1
,
2
,
0
))
# *******************************************************
# Save models
torch
.
save
(
new_mapping
.
state_dict
(),
'
training_files/models/new_mapping_{}.pth
'
.
format
(
epoch
+
1
))
# torch.save(model_Discriminator.state_dict(), 'training_files/models/Discriminator_{}.pth'.format(epoch+1))
# Update schedulers
scheduler1_Generator
.
step
()
scheduler2_Generator
.
step
()
scheduler3_Generator
.
step
()
scheduler_Discriminator
.
step
()
#========================================================
\ No newline at end of file
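The discriminator update above follows the WGAN recipe: a critic scores latents, its loss is the difference of mean scores on real and generated latents, and weight clipping keeps it (approximately) Lipschitz. A minimal self-contained sketch of that pattern, with a toy linear critic and synthetic latent shapes standing in for the real W-space latents:

import torch

# Toy critic and synthetic latents (illustration only; the real code scores W+ latents).
critic = torch.nn.Linear(512, 1)
opt = torch.optim.Adam(critic.parameters(), lr=1e-4)

w_real = torch.randn(6, 512)  # latents from the pretrained EG3D mapping
w_fake = torch.randn(6, 512)  # latents from the new mapping network (detached)

# Critic is trained to give low scores to real latents and high scores to fake ones,
# matching the (errD_real - output_fake) objective above.
loss_critic = (critic(w_real) - critic(w_fake)).mean()
opt.zero_grad()
loss_critic.backward()
opt.step()

# Weight clipping, as in the original WGAN, bounds the critic's weights.
for p in critic.parameters():
    p.data.clamp_(-0.01, 0.01)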
transformers.py (new file, mode 100644; +390 −0)
import os, sys
import pickle

import numpy as np
import torch
from sklearn.base import TransformerMixin, BaseEstimator

from bob.pipelines import SampleBatch, Sample, SampleSet
from src.loss.FaceIDLoss import Crop_and_resize, get_FaceRecognition_transformer
from camera_utils import LookAtPoseSampler, FOV_to_intrinsics


class GaFaR_InversionTransformer(TransformerMixin, BaseEstimator):
    r"""
    Transforms any :math:`\mathbb{R}^n` into an image :math:`\mathbb{R}^{h \times w \times c}`.

    Parameters
    ----------
    checkpoint: str
        Checkpoint of the image generator
    eg3d_checkpoint: str
        Checkpoint of the EG3D model
    generator:
        instance of the generator network
    """

    def __init__(self, checkpoint, eg3d_checkpoint, generator=None):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        with open(eg3d_checkpoint, 'rb') as f:
            EG3D = pickle.load(f)['G_ema']
            EG3D.to(self.device)
            EG3D.eval()
        EG3D_mapping = EG3D.mapping
        self.EG3D_synthesis = EG3D.synthesis

        if generator is None:
            from src.Network import MappingNetwork
            self.generator = MappingNetwork(
                z_dim=16,       # Input latent (Z) dimensionality.
                c_dim=512,      # Conditioning label (C) dimensionality, 0 = no labels.
                w_dim=512,      # Intermediate latent (W) dimensionality.
                num_ws=14,      # Number of intermediate latents to output.
                num_layers=2,   # Number of mapping layers.
            )
        else:
            self.generator = generator  # TODO: use the checkpoint variable here
        self.generator.load_state_dict(torch.load(checkpoint, map_location=self.device))
        self.generator.eval()
        self.generator.to(self.device)

        self.checkpoint = checkpoint
        self.eg3d_checkpoint = eg3d_checkpoint

        # Fixed frontal camera
        fov_deg = 18.837
        cam2world_pose = LookAtPoseSampler.sample(np.pi / 2, np.pi / 2,
                                                  torch.tensor([0, 0, 0.2], device=self.device),
                                                  radius=2.7, device=self.device)
        intrinsics = FOV_to_intrinsics(fov_deg, device=self.device)
        self.camera_params = torch.cat([cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1)  # camera parameters

    def _more_tags(self):
        return {"stateless": True, "requires_fit": False}

    def fit(self, X, y=None):
        return self

    def transform(self, samples):
        def _transform(data):
            data = data.flatten()
            data = np.reshape(data, (1, data.shape[0]))
            embedding = torch.Tensor(data).to(self.device)

            z = torch.randn([1, self.generator.z_dim]).to(self.device)  # latent codes
            w = self.generator(z=z, c=embedding)
            reconstructed_img = self.EG3D_synthesis(w, self.camera_params)['image']  # NCHW, float32, dynamic range [-1, +1], no truncation
            # noise = torch.randn(embedding.size(0), self.generator.z_dim, device=self.device)
            # w = self.generator(z=noise, c=embedding)
            # reconstructed_img = self.StyleGAN_synthesis(w)
            reconstructed_img = torch.clamp(reconstructed_img, min=-1, max=1)
            reconstructed_img = (reconstructed_img + 1) / 2.0
            reconstructed_face = Crop_and_resize(reconstructed_img)[0]
            return reconstructed_face.cpu().detach().numpy() * 255.0

        if isinstance(samples[0], SampleSet):
            return [SampleSet(self.transform(sset.samples), parent=sset) for sset in samples]
        else:
            return [Sample(_transform(sample.data), parent=sample) for sample in samples]


class GaFaR_CO_InversionTransformer(TransformerMixin, BaseEstimator):
    r"""
    Transforms any :math:`\mathbb{R}^n` into an image :math:`\mathbb{R}^{h \times w \times c}`,
    optimizing the camera rotation (CO) to best match the leaked template.

    Parameters
    ----------
    checkpoint: str
        Checkpoint of the image generator
    eg3d_checkpoint: str
        Checkpoint of the EG3D model
    FR_system: str
        Face recognition system (database)
    generator:
        instance of the generator network
    """

    def __init__(self, checkpoint, eg3d_checkpoint, FR_system, generator=None):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        with open(eg3d_checkpoint, 'rb') as f:
            EG3D = pickle.load(f)['G_ema']
            EG3D.to(self.device)
            EG3D.eval()
        EG3D_mapping = EG3D.mapping
        self.EG3D_synthesis = EG3D.synthesis

        if generator is None:
            from src.Network import MappingNetwork
            self.generator = MappingNetwork(
                z_dim=16,       # Input latent (Z) dimensionality.
                c_dim=512,      # Conditioning label (C) dimensionality, 0 = no labels.
                w_dim=512,      # Intermediate latent (W) dimensionality.
                num_ws=14,      # Number of intermediate latents to output.
                num_layers=2,   # Number of mapping layers.
            )
        else:
            self.generator = generator  # TODO: use the checkpoint variable here
        self.generator.load_state_dict(torch.load(checkpoint, map_location=self.device))
        self.generator.eval()
        self.generator.to(self.device)

        self.checkpoint = checkpoint
        self.eg3d_checkpoint = eg3d_checkpoint
        self.FR_system = FR_system

        # Fixed frontal camera
        self.fov_deg = 18.837
        cam2world_pose = LookAtPoseSampler.sample(np.pi / 2, np.pi / 2,
                                                  torch.tensor([0, 0, 0.2], device=self.device),
                                                  radius=2.7, device=self.device)
        intrinsics = FOV_to_intrinsics(self.fov_deg, device=self.device)
        self.camera_params = torch.cat([cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1)  # camera parameters

        from detect_align import detectLM_align
        self.align = detectLM_align(
            detector_path='./InsightFace-PyTorch/retinaface/weights/mobilenet0.25_Final.pth',
            device=self.device)
        self.FaceRecognition_transformer = get_FaceRecognition_transformer(
            FR_system=FR_system, device=self.device)
        # warm-up call: ._load_model(), eval()
        _ = self.FaceRecognition_transformer.transform(torch.zeros([1, 3, 112, 112]).to(self.device))

    def _more_tags(self):
        return {"stateless": True, "requires_fit": False}

    def fit(self, X, y=None):
        return self

    def transform(self, samples):
        def _transform(data):
            data = data.flatten()
            data = np.reshape(data, (1, data.shape[0]))
            embedding = torch.Tensor(data).to(self.device)

            z = torch.randn([1, self.generator.z_dim]).to(self.device)  # latent codes
            w = self.generator(z=z, c=embedding).detach()

            # Optimize two camera rotation offsets by gradient descent
            cam_rotation_param = torch.zeros(2, requires_grad=True, device=self.device)
            optimizer = torch.optim.Adam([cam_rotation_param], lr=1e-2)

            fov_deg = 18.837
            cam2world_pose = LookAtPoseSampler.sample(np.pi / 2 + cam_rotation_param[0],
                                                      np.pi / 2 + cam_rotation_param[1],
                                                      torch.tensor([0, 0, 0.2], device=self.device),
                                                      radius=2.7, device=self.device)
            intrinsics = FOV_to_intrinsics(fov_deg, device=self.device)
            camera_params = torch.cat([cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1)  # camera parameters

            reconstructed_img = self.EG3D_synthesis(w, camera_params)['image']  # NCHW, float32, dynamic range [-1, +1], no truncation
            reconstructed_img = torch.clamp(reconstructed_img, min=-1, max=1)
            reconstructed_img = (reconstructed_img + 1) / 2.0
            best_reconstructed_face = Crop_and_resize(reconstructed_img)[0] * 255.0

            emb = self.FaceRecognition_transformer.model(
                (best_reconstructed_face.unsqueeze(0) - 127.5) / 128.0)
            loss = torch.nn.MSELoss()(embedding, emb)
            best_loss = loss.item()
            print('frontal score', best_loss)

            import time
            t0 = time.time()
            for i in range(121):
                print(i)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                # keep the rotation offsets within a plausible range
                cam_rotation_param[0].data = torch.clamp(cam_rotation_param[0], min=-np.pi / 4, max=np.pi / 4)
                cam_rotation_param[1].data = torch.clamp(cam_rotation_param[1], min=-np.pi / 6, max=np.pi / 6)

                fov_deg = 18.837
                cam2world_pose = LookAtPoseSampler.sample(np.pi / 2 + cam_rotation_param[0],
                                                          np.pi / 2 + cam_rotation_param[1],
                                                          torch.tensor([0, 0, 0.2], device=self.device),
                                                          radius=2.7, device=self.device)
                intrinsics = FOV_to_intrinsics(fov_deg, device=self.device)
                camera_params = torch.cat([cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1)  # camera parameters

                reconstructed_img = self.EG3D_synthesis(w, camera_params)['image']  # NCHW, float32, dynamic range [-1, +1], no truncation
                reconstructed_img = (torch.clamp(reconstructed_img, min=-1, max=1) + 1) / 2.0 * 255.
                try:
                    reconstructed_img_align = self.align(reconstructed_img)
                except:
                    break
                emb = self.FaceRecognition_transformer.model(
                    (reconstructed_img_align.unsqueeze(0) - 127.5) / 128.0)
                loss = torch.nn.MSELoss()(embedding, emb)
                if loss.item() < best_loss:
                    best_loss = loss.item()
                    best_reconstructed_face = reconstructed_img_align
                print(best_loss)
            print(time.time() - t0)
            return best_reconstructed_face.cpu().detach().numpy()

        if isinstance(samples[0], SampleSet):
            return [SampleSet(self.transform(sset.samples), parent=sset) for sset in samples]
        else:
            return [Sample(_transform(sample.data), parent=sample) for sample in samples]


class GaFaR_GS_InversionTransformer(TransformerMixin, BaseEstimator):
    r"""
    Transforms any :math:`\mathbb{R}^n` into an image :math:`\mathbb{R}^{h \times w \times c}`,
    using a grid search (GS) over camera rotations to best match the leaked template.

    Parameters
    ----------
    checkpoint: str
        Checkpoint of the image generator
    eg3d_checkpoint: str
        Checkpoint of the EG3D model
    FR_system: str
        Face recognition system (database)
    generator:
        instance of the generator network
    """

    def __init__(self, checkpoint, eg3d_checkpoint, FR_system, generator=None):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        with open(eg3d_checkpoint, 'rb') as f:
            EG3D = pickle.load(f)['G_ema']
            EG3D.to(self.device)
            EG3D.eval()
        EG3D_mapping = EG3D.mapping
        self.EG3D_synthesis = EG3D.synthesis

        if generator is None:
            from src.Network import MappingNetwork
            self.generator = MappingNetwork(
                z_dim=16,       # Input latent (Z) dimensionality.
                c_dim=512,      # Conditioning label (C) dimensionality, 0 = no labels.
                w_dim=512,      # Intermediate latent (W) dimensionality.
                num_ws=14,      # Number of intermediate latents to output.
                num_layers=2,   # Number of mapping layers.
            )
        else:
            self.generator = generator  # TODO: use the checkpoint variable here
        self.generator.load_state_dict(torch.load(checkpoint, map_location=self.device))
        self.generator.eval()
        self.generator.to(self.device)

        self.checkpoint = checkpoint
        self.eg3d_checkpoint = eg3d_checkpoint
        self.FR_system = FR_system

        # Fixed frontal camera
        self.fov_deg = 18.837
        cam2world_pose = LookAtPoseSampler.sample(np.pi / 2, np.pi / 2,
                                                  torch.tensor([0, 0, 0.2], device=self.device),
                                                  radius=2.7, device=self.device)
        intrinsics = FOV_to_intrinsics(self.fov_deg, device=self.device)
        self.camera_params = torch.cat([cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1)  # camera parameters

        from detect_align import detectLM_align
        self.align = detectLM_align(
            detector_path='./InsightFace-PyTorch/retinaface/weights/mobilenet0.25_Final.pth',
            device=self.device)
        self.FaceRecognition_transformer = get_FaceRecognition_transformer(
            FR_system=FR_system, device=self.device)
        # warm-up call: ._load_model(), eval()
        _ = self.FaceRecognition_transformer.transform(torch.zeros([1, 3, 112, 112]).to(self.device))

    def _more_tags(self):
        return {"stateless": True, "requires_fit": False}

    def fit(self, X, y=None):
        return self

    def transform(self, samples):
        def _transform(data):
            data = data.flatten()
            data = np.reshape(data, (1, data.shape[0]))
            embedding = torch.Tensor(data).to(self.device)

            z = torch.randn([1, self.generator.z_dim]).to(self.device)  # latent codes
            w = self.generator(z=z, c=embedding)

            reconstructed_img = self.EG3D_synthesis(w, self.camera_params)['image']  # NCHW, float32, dynamic range [-1, +1], no truncation
            reconstructed_img = torch.clamp(reconstructed_img, min=-1, max=1)
            reconstructed_img = (reconstructed_img + 1) / 2.0
            best_reconstructed_face = Crop_and_resize(reconstructed_img)[0] * 255.0

            emb = self.FaceRecognition_transformer.model(
                (best_reconstructed_face.unsqueeze(0) - 127.5) / 128.0)
            best_dissim = torch.nn.MSELoss()(embedding, emb)
            print('frontal score', best_dissim)

            import time
            t0 = time.time()
            for f in np.linspace(start=-np.pi / 4, stop=np.pi / 4, num=11, endpoint=True):  # yaw
                for t in np.linspace(start=-np.pi / 6, stop=np.pi / 6, num=11, endpoint=True):
                    print(f, t)
                    fov_deg = 18.837
                    cam2world_pose = LookAtPoseSampler.sample(np.pi / 2 + t, np.pi / 2 + f,
                                                              torch.tensor([0, 0, 0.2], device=self.device),
                                                              radius=2.7, device=self.device)
                    intrinsics = FOV_to_intrinsics(fov_deg, device=self.device)
                    camera_params = torch.cat([cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1)  # camera parameters

                    reconstructed_img = self.EG3D_synthesis(w, camera_params)['image']  # NCHW, float32, dynamic range [-1, +1], no truncation
                    reconstructed_img = (torch.clamp(reconstructed_img, min=-1, max=1) + 1) / 2.0 * 255.
                    try:
                        reconstructed_img_align = self.align(reconstructed_img)
                    except:
                        continue
                    emb = self.FaceRecognition_transformer.model(
                        (reconstructed_img_align.unsqueeze(0) - 127.5) / 128.0)
                    dissim = torch.nn.MSELoss()(embedding, emb)
                    if dissim < best_dissim:
                        best_dissim = dissim
                        best_reconstructed_face = reconstructed_img_align
                    print(best_dissim)
            print(time.time() - t0)
            return best_reconstructed_face.cpu().detach().numpy()

        if isinstance(samples[0], SampleSet):
            return [SampleSet(self.transform(sset.samples), parent=sset) for sset in samples]
        else:
            return [Sample(_transform(sample.data), parent=sample) for sample in samples]
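A minimal usage sketch of the frontal-pose transformer (the checkpoint paths are placeholders that must point to real files, and the Sample key is hypothetical; the random vector stands in for a leaked 512-d template):

import numpy as np
from bob.pipelines import Sample
from transformers import GaFaR_InversionTransformer

inverter = GaFaR_InversionTransformer(
    checkpoint='./training_files/models/new_mapping_15.pth',
    eg3d_checkpoint='./ffhq512-128.pkl',
)
# Wrap a leaked 512-d template as a bob.pipelines Sample:
template = Sample(np.random.randn(512).astype('float32'), key='subject_001')
reconstructed = inverter.transform([template])  # list of Samples holding face images
print(reconstructed[0].data.shape)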