From b894b8299d42af643e0be75d68b30d7d30450163 Mon Sep 17 00:00:00 2001 From: Tiago Freitas Pereira <tiagofrepereira@gmail.com> Date: Wed, 9 Jun 2021 13:53:22 +0200 Subject: [PATCH] Patched pytorch models --- bob/bio/face/embeddings/pytorch.py | 46 +++++++++++++++++++++--------- 1 file changed, 32 insertions(+), 14 deletions(-) diff --git a/bob/bio/face/embeddings/pytorch.py b/bob/bio/face/embeddings/pytorch.py index 601ac06a..2007865d 100644 --- a/bob/bio/face/embeddings/pytorch.py +++ b/bob/bio/face/embeddings/pytorch.py @@ -44,13 +44,16 @@ class PyTorchModel(TransformerMixin, BaseEstimator): checkpoint_path=None, config=None, preprocessor=lambda x: x / 255, + memory_demanding=False, **kwargs ): + super().__init__(**kwargs) self.checkpoint_path = checkpoint_path self.config = config self.model = None self.preprocessor = preprocessor + self.memory_demanding = memory_demanding def transform(self, X): """__call__(image) -> feature @@ -74,7 +77,14 @@ class PyTorchModel(TransformerMixin, BaseEstimator): X = check_array(X, allow_nd=True) X = torch.Tensor(X) X = self.preprocessor(X) - return self.model(X).detach().numpy() + + def _transform(X): + return self.model(X).detach().numpy() + + if self.memory_demanding: + return np.array([_transform(x[None, ...]) for x in X]) + else: + return _transform(X) def __getstate__(self): # Handling unpicklable objects @@ -93,7 +103,7 @@ class AFFFE_2021(PyTorchModel): """ - def __init__(self): + def __init__(self, memory_demanding=False): urls = [ "https://www.idiap.ch/software/bob/data/bob/bob.bio.face/master/pytorch/AFFFE-42a53f19.tar.gz", @@ -111,7 +121,9 @@ class AFFFE_2021(PyTorchModel): config = os.path.join(path, "AFFFE.py") checkpoint_path = os.path.join(path, "AFFFE.pth") - super(AFFFE_2021, self).__init__(checkpoint_path, config) + super(AFFFE_2021, self).__init__( + checkpoint_path, config, memory_demanding=memory_demanding + ) def _load_model(self): @@ -148,7 +160,7 @@ class IResnet34(PyTorchModel): ArcFace model (RESNET 34) from Insightface ported to pytorch """ - def __init__(self): + def __init__(self, memory_demanding=False): urls = [ "https://www.idiap.ch/software/bob/data/bob/bob.bio.face/master/pytorch/iresnet-91a5de61.tar.gz", @@ -161,7 +173,9 @@ class IResnet34(PyTorchModel): config = os.path.join(path, "iresnet.py") checkpoint_path = os.path.join(path, "iresnet34-5b0d0e90.pth") - super(IResnet34, self).__init__(checkpoint_path, config) + super(IResnet34, self).__init__( + checkpoint_path, config, memory_demanding=memory_demanding + ) def _load_model(self): @@ -174,7 +188,7 @@ class IResnet50(PyTorchModel): ArcFace model (RESNET 50) from Insightface ported to pytorch """ - def __init__(self): + def __init__(self, memory_demanding=False): filename = _get_iresnet_file() @@ -182,7 +196,9 @@ class IResnet50(PyTorchModel): config = os.path.join(path, "iresnet.py") checkpoint_path = os.path.join(path, "iresnet50-7f187506.pth") - super(IResnet50, self).__init__(checkpoint_path, config) + super(IResnet50, self).__init__( + checkpoint_path, config, memory_demanding=memory_demanding + ) def _load_model(self): @@ -195,7 +211,7 @@ class IResnet100(PyTorchModel): ArcFace model (RESNET 100) from Insightface ported to pytorch """ - def __init__(self): + def __init__(self, memory_demanding=False): filename = _get_iresnet_file() @@ -203,7 +219,9 @@ class IResnet100(PyTorchModel): config = os.path.join(path, "iresnet.py") checkpoint_path = os.path.join(path, "iresnet100-73e07ba7.pth") - super(IResnet100, self).__init__(checkpoint_path, config) + super(IResnet100, self).__init__( + checkpoint_path, config, memory_demanding=memory_demanding + ) def _load_model(self): @@ -261,7 +279,7 @@ def iresnet34(annotation_type, fixed_positions=None, memory_demanding=False): """ return iresnet_template( - embedding=IResnet34(), + embedding=IResnet34(memory_demanding=memory_demanding), annotation_type=annotation_type, fixed_positions=fixed_positions, ) @@ -291,7 +309,7 @@ def iresnet50(annotation_type, fixed_positions=None, memory_demanding=False): """ return iresnet_template( - embedding=IResnet50(), + embedding=IResnet50(memory_demanding=memory_demanding), annotation_type=annotation_type, fixed_positions=fixed_positions, ) @@ -321,13 +339,13 @@ def iresnet100(annotation_type, fixed_positions=None, memory_demanding=False): """ return iresnet_template( - embedding=IResnet100(), + embedding=IResnet100(memory_demanding=memory_demanding), annotation_type=annotation_type, fixed_positions=fixed_positions, ) -def afffe_baseline(annotation_type, fixed_positions=None): +def afffe_baseline(annotation_type, fixed_positions=None, memory_demanding=False): """ Get the AFFFE pipeline which will crop the face :math:`224 \times 224` use the :py:class:`AFFFE_2021` @@ -353,7 +371,7 @@ def afffe_baseline(annotation_type, fixed_positions=None): transformer = embedding_transformer( cropped_image_size=cropped_image_size, - embedding=AFFFE_2021(), + embedding=AFFFE_2021(memory_demanding=memory_demanding), cropped_positions=cropped_positions, fixed_positions=fixed_positions, color_channel="rgb", -- GitLab