Commit b894b829 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Patched pytorch models

parent 174d8289
Pipeline #51340 passed with stage
in 30 minutes and 35 seconds
......@@ -44,13 +44,16 @@ class PyTorchModel(TransformerMixin, BaseEstimator):
checkpoint_path=None,
config=None,
preprocessor=lambda x: x / 255,
memory_demanding=False,
**kwargs
):
super().__init__(**kwargs)
self.checkpoint_path = checkpoint_path
self.config = config
self.model = None
self.preprocessor = preprocessor
self.memory_demanding = memory_demanding
def transform(self, X):
"""__call__(image) -> feature
......@@ -74,7 +77,14 @@ class PyTorchModel(TransformerMixin, BaseEstimator):
X = check_array(X, allow_nd=True)
X = torch.Tensor(X)
X = self.preprocessor(X)
return self.model(X).detach().numpy()
def _transform(X):
return self.model(X).detach().numpy()
if self.memory_demanding:
return np.array([_transform(x[None, ...]) for x in X])
else:
return _transform(X)
def __getstate__(self):
# Handling unpicklable objects
......@@ -93,7 +103,7 @@ class AFFFE_2021(PyTorchModel):
"""
def __init__(self):
def __init__(self, memory_demanding=False):
urls = [
"https://www.idiap.ch/software/bob/data/bob/bob.bio.face/master/pytorch/AFFFE-42a53f19.tar.gz",
......@@ -111,7 +121,9 @@ class AFFFE_2021(PyTorchModel):
config = os.path.join(path, "AFFFE.py")
checkpoint_path = os.path.join(path, "AFFFE.pth")
super(AFFFE_2021, self).__init__(checkpoint_path, config)
super(AFFFE_2021, self).__init__(
checkpoint_path, config, memory_demanding=memory_demanding
)
def _load_model(self):
......@@ -148,7 +160,7 @@ class IResnet34(PyTorchModel):
ArcFace model (RESNET 34) from Insightface ported to pytorch
"""
def __init__(self):
def __init__(self, memory_demanding=False):
urls = [
"https://www.idiap.ch/software/bob/data/bob/bob.bio.face/master/pytorch/iresnet-91a5de61.tar.gz",
......@@ -161,7 +173,9 @@ class IResnet34(PyTorchModel):
config = os.path.join(path, "iresnet.py")
checkpoint_path = os.path.join(path, "iresnet34-5b0d0e90.pth")
super(IResnet34, self).__init__(checkpoint_path, config)
super(IResnet34, self).__init__(
checkpoint_path, config, memory_demanding=memory_demanding
)
def _load_model(self):
......@@ -174,7 +188,7 @@ class IResnet50(PyTorchModel):
ArcFace model (RESNET 50) from Insightface ported to pytorch
"""
def __init__(self):
def __init__(self, memory_demanding=False):
filename = _get_iresnet_file()
......@@ -182,7 +196,9 @@ class IResnet50(PyTorchModel):
config = os.path.join(path, "iresnet.py")
checkpoint_path = os.path.join(path, "iresnet50-7f187506.pth")
super(IResnet50, self).__init__(checkpoint_path, config)
super(IResnet50, self).__init__(
checkpoint_path, config, memory_demanding=memory_demanding
)
def _load_model(self):
......@@ -195,7 +211,7 @@ class IResnet100(PyTorchModel):
ArcFace model (RESNET 100) from Insightface ported to pytorch
"""
def __init__(self):
def __init__(self, memory_demanding=False):
filename = _get_iresnet_file()
......@@ -203,7 +219,9 @@ class IResnet100(PyTorchModel):
config = os.path.join(path, "iresnet.py")
checkpoint_path = os.path.join(path, "iresnet100-73e07ba7.pth")
super(IResnet100, self).__init__(checkpoint_path, config)
super(IResnet100, self).__init__(
checkpoint_path, config, memory_demanding=memory_demanding
)
def _load_model(self):
......@@ -261,7 +279,7 @@ def iresnet34(annotation_type, fixed_positions=None, memory_demanding=False):
"""
return iresnet_template(
embedding=IResnet34(),
embedding=IResnet34(memory_demanding=memory_demanding),
annotation_type=annotation_type,
fixed_positions=fixed_positions,
)
......@@ -291,7 +309,7 @@ def iresnet50(annotation_type, fixed_positions=None, memory_demanding=False):
"""
return iresnet_template(
embedding=IResnet50(),
embedding=IResnet50(memory_demanding=memory_demanding),
annotation_type=annotation_type,
fixed_positions=fixed_positions,
)
......@@ -321,13 +339,13 @@ def iresnet100(annotation_type, fixed_positions=None, memory_demanding=False):
"""
return iresnet_template(
embedding=IResnet100(),
embedding=IResnet100(memory_demanding=memory_demanding),
annotation_type=annotation_type,
fixed_positions=fixed_positions,
)
def afffe_baseline(annotation_type, fixed_positions=None):
def afffe_baseline(annotation_type, fixed_positions=None, memory_demanding=False):
"""
Get the AFFFE pipeline which will crop the face :math:`224 \times 224`
use the :py:class:`AFFFE_2021`
......@@ -353,7 +371,7 @@ def afffe_baseline(annotation_type, fixed_positions=None):
transformer = embedding_transformer(
cropped_image_size=cropped_image_size,
embedding=AFFFE_2021(),
embedding=AFFFE_2021(memory_demanding=memory_demanding),
cropped_positions=cropped_positions,
fixed_positions=fixed_positions,
color_channel="rgb",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment