Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
bob
bob.bio.face
Commits
b894b829
Commit
b894b829
authored
Jun 09, 2021
by
Tiago de Freitas Pereira
Browse files
Patched pytorch models
parent
174d8289
Pipeline
#51340
passed with stage
in 30 minutes and 35 seconds
Changes
1
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
bob/bio/face/embeddings/pytorch.py
View file @
b894b829
...
...
@@ -44,13 +44,16 @@ class PyTorchModel(TransformerMixin, BaseEstimator):
checkpoint_path
=
None
,
config
=
None
,
preprocessor
=
lambda
x
:
x
/
255
,
memory_demanding
=
False
,
**
kwargs
):
super
().
__init__
(
**
kwargs
)
self
.
checkpoint_path
=
checkpoint_path
self
.
config
=
config
self
.
model
=
None
self
.
preprocessor
=
preprocessor
self
.
memory_demanding
=
memory_demanding
def
transform
(
self
,
X
):
"""__call__(image) -> feature
...
...
@@ -74,7 +77,14 @@ class PyTorchModel(TransformerMixin, BaseEstimator):
X
=
check_array
(
X
,
allow_nd
=
True
)
X
=
torch
.
Tensor
(
X
)
X
=
self
.
preprocessor
(
X
)
return
self
.
model
(
X
).
detach
().
numpy
()
def
_transform
(
X
):
return
self
.
model
(
X
).
detach
().
numpy
()
if
self
.
memory_demanding
:
return
np
.
array
([
_transform
(
x
[
None
,
...])
for
x
in
X
])
else
:
return
_transform
(
X
)
def
__getstate__
(
self
):
# Handling unpicklable objects
...
...
@@ -93,7 +103,7 @@ class AFFFE_2021(PyTorchModel):
"""
def
__init__
(
self
):
def
__init__
(
self
,
memory_demanding
=
False
):
urls
=
[
"https://www.idiap.ch/software/bob/data/bob/bob.bio.face/master/pytorch/AFFFE-42a53f19.tar.gz"
,
...
...
@@ -111,7 +121,9 @@ class AFFFE_2021(PyTorchModel):
config
=
os
.
path
.
join
(
path
,
"AFFFE.py"
)
checkpoint_path
=
os
.
path
.
join
(
path
,
"AFFFE.pth"
)
super
(
AFFFE_2021
,
self
).
__init__
(
checkpoint_path
,
config
)
super
(
AFFFE_2021
,
self
).
__init__
(
checkpoint_path
,
config
,
memory_demanding
=
memory_demanding
)
def
_load_model
(
self
):
...
...
@@ -148,7 +160,7 @@ class IResnet34(PyTorchModel):
ArcFace model (RESNET 34) from Insightface ported to pytorch
"""
def
__init__
(
self
):
def
__init__
(
self
,
memory_demanding
=
False
):
urls
=
[
"https://www.idiap.ch/software/bob/data/bob/bob.bio.face/master/pytorch/iresnet-91a5de61.tar.gz"
,
...
...
@@ -161,7 +173,9 @@ class IResnet34(PyTorchModel):
config
=
os
.
path
.
join
(
path
,
"iresnet.py"
)
checkpoint_path
=
os
.
path
.
join
(
path
,
"iresnet34-5b0d0e90.pth"
)
super
(
IResnet34
,
self
).
__init__
(
checkpoint_path
,
config
)
super
(
IResnet34
,
self
).
__init__
(
checkpoint_path
,
config
,
memory_demanding
=
memory_demanding
)
def
_load_model
(
self
):
...
...
@@ -174,7 +188,7 @@ class IResnet50(PyTorchModel):
ArcFace model (RESNET 50) from Insightface ported to pytorch
"""
def
__init__
(
self
):
def
__init__
(
self
,
memory_demanding
=
False
):
filename
=
_get_iresnet_file
()
...
...
@@ -182,7 +196,9 @@ class IResnet50(PyTorchModel):
config
=
os
.
path
.
join
(
path
,
"iresnet.py"
)
checkpoint_path
=
os
.
path
.
join
(
path
,
"iresnet50-7f187506.pth"
)
super
(
IResnet50
,
self
).
__init__
(
checkpoint_path
,
config
)
super
(
IResnet50
,
self
).
__init__
(
checkpoint_path
,
config
,
memory_demanding
=
memory_demanding
)
def
_load_model
(
self
):
...
...
@@ -195,7 +211,7 @@ class IResnet100(PyTorchModel):
ArcFace model (RESNET 100) from Insightface ported to pytorch
"""
def
__init__
(
self
):
def
__init__
(
self
,
memory_demanding
=
False
):
filename
=
_get_iresnet_file
()
...
...
@@ -203,7 +219,9 @@ class IResnet100(PyTorchModel):
config
=
os
.
path
.
join
(
path
,
"iresnet.py"
)
checkpoint_path
=
os
.
path
.
join
(
path
,
"iresnet100-73e07ba7.pth"
)
super
(
IResnet100
,
self
).
__init__
(
checkpoint_path
,
config
)
super
(
IResnet100
,
self
).
__init__
(
checkpoint_path
,
config
,
memory_demanding
=
memory_demanding
)
def
_load_model
(
self
):
...
...
@@ -261,7 +279,7 @@ def iresnet34(annotation_type, fixed_positions=None, memory_demanding=False):
"""
return
iresnet_template
(
embedding
=
IResnet34
(),
embedding
=
IResnet34
(
memory_demanding
=
memory_demanding
),
annotation_type
=
annotation_type
,
fixed_positions
=
fixed_positions
,
)
...
...
@@ -291,7 +309,7 @@ def iresnet50(annotation_type, fixed_positions=None, memory_demanding=False):
"""
return
iresnet_template
(
embedding
=
IResnet50
(),
embedding
=
IResnet50
(
memory_demanding
=
memory_demanding
),
annotation_type
=
annotation_type
,
fixed_positions
=
fixed_positions
,
)
...
...
@@ -321,13 +339,13 @@ def iresnet100(annotation_type, fixed_positions=None, memory_demanding=False):
"""
return
iresnet_template
(
embedding
=
IResnet100
(),
embedding
=
IResnet100
(
memory_demanding
=
memory_demanding
),
annotation_type
=
annotation_type
,
fixed_positions
=
fixed_positions
,
)
def
afffe_baseline
(
annotation_type
,
fixed_positions
=
None
):
def
afffe_baseline
(
annotation_type
,
fixed_positions
=
None
,
memory_demanding
=
False
):
"""
Get the AFFFE pipeline which will crop the face :math:`224
\t
imes 224`
use the :py:class:`AFFFE_2021`
...
...
@@ -353,7 +371,7 @@ def afffe_baseline(annotation_type, fixed_positions=None):
transformer
=
embedding_transformer
(
cropped_image_size
=
cropped_image_size
,
embedding
=
AFFFE_2021
(),
embedding
=
AFFFE_2021
(
memory_demanding
=
memory_demanding
),
cropped_positions
=
cropped_positions
,
fixed_positions
=
fixed_positions
,
color_channel
=
"rgb"
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment