Skip to content
Snippets Groups Projects
Commit b894b829 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Patched pytorch models

parent 174d8289
No related branches found
No related tags found
1 merge request!112Feature extractors
Pipeline #51340 passed
......@@ -44,13 +44,16 @@ class PyTorchModel(TransformerMixin, BaseEstimator):
checkpoint_path=None,
config=None,
preprocessor=lambda x: x / 255,
memory_demanding=False,
**kwargs
):
super().__init__(**kwargs)
self.checkpoint_path = checkpoint_path
self.config = config
self.model = None
self.preprocessor = preprocessor
self.memory_demanding = memory_demanding
def transform(self, X):
"""__call__(image) -> feature
......@@ -74,7 +77,14 @@ class PyTorchModel(TransformerMixin, BaseEstimator):
X = check_array(X, allow_nd=True)
X = torch.Tensor(X)
X = self.preprocessor(X)
return self.model(X).detach().numpy()
def _transform(X):
return self.model(X).detach().numpy()
if self.memory_demanding:
return np.array([_transform(x[None, ...]) for x in X])
else:
return _transform(X)
def __getstate__(self):
# Handling unpicklable objects
......@@ -93,7 +103,7 @@ class AFFFE_2021(PyTorchModel):
"""
def __init__(self):
def __init__(self, memory_demanding=False):
urls = [
"https://www.idiap.ch/software/bob/data/bob/bob.bio.face/master/pytorch/AFFFE-42a53f19.tar.gz",
......@@ -111,7 +121,9 @@ class AFFFE_2021(PyTorchModel):
config = os.path.join(path, "AFFFE.py")
checkpoint_path = os.path.join(path, "AFFFE.pth")
super(AFFFE_2021, self).__init__(checkpoint_path, config)
super(AFFFE_2021, self).__init__(
checkpoint_path, config, memory_demanding=memory_demanding
)
def _load_model(self):
......@@ -148,7 +160,7 @@ class IResnet34(PyTorchModel):
ArcFace model (RESNET 34) from Insightface ported to pytorch
"""
def __init__(self):
def __init__(self, memory_demanding=False):
urls = [
"https://www.idiap.ch/software/bob/data/bob/bob.bio.face/master/pytorch/iresnet-91a5de61.tar.gz",
......@@ -161,7 +173,9 @@ class IResnet34(PyTorchModel):
config = os.path.join(path, "iresnet.py")
checkpoint_path = os.path.join(path, "iresnet34-5b0d0e90.pth")
super(IResnet34, self).__init__(checkpoint_path, config)
super(IResnet34, self).__init__(
checkpoint_path, config, memory_demanding=memory_demanding
)
def _load_model(self):
......@@ -174,7 +188,7 @@ class IResnet50(PyTorchModel):
ArcFace model (RESNET 50) from Insightface ported to pytorch
"""
def __init__(self):
def __init__(self, memory_demanding=False):
filename = _get_iresnet_file()
......@@ -182,7 +196,9 @@ class IResnet50(PyTorchModel):
config = os.path.join(path, "iresnet.py")
checkpoint_path = os.path.join(path, "iresnet50-7f187506.pth")
super(IResnet50, self).__init__(checkpoint_path, config)
super(IResnet50, self).__init__(
checkpoint_path, config, memory_demanding=memory_demanding
)
def _load_model(self):
......@@ -195,7 +211,7 @@ class IResnet100(PyTorchModel):
ArcFace model (RESNET 100) from Insightface ported to pytorch
"""
def __init__(self):
def __init__(self, memory_demanding=False):
filename = _get_iresnet_file()
......@@ -203,7 +219,9 @@ class IResnet100(PyTorchModel):
config = os.path.join(path, "iresnet.py")
checkpoint_path = os.path.join(path, "iresnet100-73e07ba7.pth")
super(IResnet100, self).__init__(checkpoint_path, config)
super(IResnet100, self).__init__(
checkpoint_path, config, memory_demanding=memory_demanding
)
def _load_model(self):
......@@ -261,7 +279,7 @@ def iresnet34(annotation_type, fixed_positions=None, memory_demanding=False):
"""
return iresnet_template(
embedding=IResnet34(),
embedding=IResnet34(memory_demanding=memory_demanding),
annotation_type=annotation_type,
fixed_positions=fixed_positions,
)
......@@ -291,7 +309,7 @@ def iresnet50(annotation_type, fixed_positions=None, memory_demanding=False):
"""
return iresnet_template(
embedding=IResnet50(),
embedding=IResnet50(memory_demanding=memory_demanding),
annotation_type=annotation_type,
fixed_positions=fixed_positions,
)
......@@ -321,13 +339,13 @@ def iresnet100(annotation_type, fixed_positions=None, memory_demanding=False):
"""
return iresnet_template(
embedding=IResnet100(),
embedding=IResnet100(memory_demanding=memory_demanding),
annotation_type=annotation_type,
fixed_positions=fixed_positions,
)
def afffe_baseline(annotation_type, fixed_positions=None):
def afffe_baseline(annotation_type, fixed_positions=None, memory_demanding=False):
"""
Get the AFFFE pipeline which will crop the face :math:`224 \times 224`
use the :py:class:`AFFFE_2021`
......@@ -353,7 +371,7 @@ def afffe_baseline(annotation_type, fixed_positions=None):
transformer = embedding_transformer(
cropped_image_size=cropped_image_size,
embedding=AFFFE_2021(),
embedding=AFFFE_2021(memory_demanding=memory_demanding),
cropped_positions=cropped_positions,
fixed_positions=fixed_positions,
color_channel="rgb",
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment