Skip to content
Snippets Groups Projects
Commit 8c9349e4 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

Merge branch 'fix-pytorch-models' into 'master'

Fix Pytorch models

See merge request !130
parents 5a81dde3 bfc285fb
Branches
Tags
1 merge request!130Fix Pytorch models
Pipeline #52480 passed
...@@ -3,21 +3,18 @@ ...@@ -3,21 +3,18 @@
# Yu Linghu & Xinyi Zhang <yu.linghu@uzh.ch, xinyi.zhang@uzh.ch> # Yu Linghu & Xinyi Zhang <yu.linghu@uzh.ch, xinyi.zhang@uzh.ch>
# Tiago de Freitas Pereira <tiago.pereira@idiap.ch> # Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
from sklearn.base import TransformerMixin, BaseEstimator
from sklearn.utils import check_array
import numpy as np
import imp import imp
import os import os
from bob.extension.download import get_file
from bob.bio.face.utils import (
dnn_default_cropping,
embedding_transformer,
)
from bob.bio.base.pipelines.vanilla_biometrics import ( import numpy as np
Distance, from bob.bio.base.pipelines.vanilla_biometrics import Distance
VanillaBiometricsPipeline, from bob.bio.base.pipelines.vanilla_biometrics import VanillaBiometricsPipeline
) from bob.bio.face.utils import dnn_default_cropping
from bob.bio.face.utils import embedding_transformer
from bob.extension.download import get_file
from sklearn.base import BaseEstimator
from sklearn.base import TransformerMixin
from sklearn.utils import check_array
class PyTorchModel(TransformerMixin, BaseEstimator): class PyTorchModel(TransformerMixin, BaseEstimator):
...@@ -45,7 +42,8 @@ class PyTorchModel(TransformerMixin, BaseEstimator): ...@@ -45,7 +42,8 @@ class PyTorchModel(TransformerMixin, BaseEstimator):
config=None, config=None,
preprocessor=lambda x: x / 255, preprocessor=lambda x: x / 255,
memory_demanding=False, memory_demanding=False,
**kwargs device=None,
**kwargs,
): ):
super().__init__(**kwargs) super().__init__(**kwargs)
...@@ -54,6 +52,7 @@ class PyTorchModel(TransformerMixin, BaseEstimator): ...@@ -54,6 +52,7 @@ class PyTorchModel(TransformerMixin, BaseEstimator):
self.model = None self.model = None
self.preprocessor = preprocessor self.preprocessor = preprocessor
self.memory_demanding = memory_demanding self.memory_demanding = memory_demanding
self.device = device
def transform(self, X): def transform(self, X):
"""__call__(image) -> feature """__call__(image) -> feature
...@@ -76,10 +75,12 @@ class PyTorchModel(TransformerMixin, BaseEstimator): ...@@ -76,10 +75,12 @@ class PyTorchModel(TransformerMixin, BaseEstimator):
self._load_model() self._load_model()
X = check_array(X, allow_nd=True) X = check_array(X, allow_nd=True)
X = torch.Tensor(X) X = torch.Tensor(X)
X = self.preprocessor(X) with torch.no_grad():
X = self.preprocessor(X)
def _transform(X): def _transform(X):
return self.model(X).detach().numpy() with torch.no_grad():
return self.model(X).cpu().detach().numpy()
if self.memory_demanding: if self.memory_demanding:
return np.array([_transform(x[None, ...]) for x in X]) return np.array([_transform(x[None, ...]) for x in X])
...@@ -96,14 +97,25 @@ class PyTorchModel(TransformerMixin, BaseEstimator): ...@@ -96,14 +97,25 @@ class PyTorchModel(TransformerMixin, BaseEstimator):
def _more_tags(self): def _more_tags(self):
return {"stateless": True, "requires_fit": False} return {"stateless": True, "requires_fit": False}
def place_model_on_device(self, device=None):
import torch
if device is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.device = device
if self.model is not None:
self.model.to(device)
class AFFFE_2021(PyTorchModel): class AFFFE_2021(PyTorchModel):
""" """
AFFFE Pytorch network that extracts 1000-dimensional features, trained by Manuel Gunther, as described in [LGB18]_ AFFFE Pytorch network that extracts 1000-dimensional features, trained by Manuel Gunther, as described in [LGB18]_
""" """
def __init__(self, memory_demanding=False): def __init__(self, memory_demanding=False, device=None, **kwargs):
urls = [ urls = [
"https://www.idiap.ch/software/bob/data/bob/bob.bio.face/master/pytorch/AFFFE-42a53f19.tar.gz", "https://www.idiap.ch/software/bob/data/bob/bob.bio.face/master/pytorch/AFFFE-42a53f19.tar.gz",
...@@ -122,22 +134,21 @@ class AFFFE_2021(PyTorchModel): ...@@ -122,22 +134,21 @@ class AFFFE_2021(PyTorchModel):
checkpoint_path = os.path.join(path, "AFFFE.pth") checkpoint_path = os.path.join(path, "AFFFE.pth")
super(AFFFE_2021, self).__init__( super(AFFFE_2021, self).__init__(
checkpoint_path, config, memory_demanding=memory_demanding checkpoint_path,
config,
memory_demanding=memory_demanding,
device=device,
**kwargs,
) )
def _load_model(self): def _load_model(self):
import torch import torch
MainModel = imp.load_source("MainModel", self.config) MainModel = imp.load_source("MainModel", self.config)
network = torch.load(self.checkpoint_path) self.model = torch.load(self.checkpoint_path, map_location=self.device)
network.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
network.to(device) self.model.eval()
self.place_model_on_device()
self.model = network
def _get_iresnet_file(): def _get_iresnet_file():
...@@ -161,14 +172,13 @@ class IResnet34(PyTorchModel): ...@@ -161,14 +172,13 @@ class IResnet34(PyTorchModel):
""" """
def __init__( def __init__(
self, preprocessor=lambda x: (x - 127.5) / 128.0, memory_demanding=False self,
preprocessor=lambda x: (x - 127.5) / 128.0,
memory_demanding=False,
device=None,
**kwargs,
): ):
urls = [
"https://www.idiap.ch/software/bob/data/bob/bob.bio.face/master/pytorch/iresnet-91a5de61.tar.gz",
"http://www.idiap.ch/software/bob/data/bob/bob.bio.face/master/pytorch/iresnet-91a5de61.tar.gz",
]
filename = _get_iresnet_file() filename = _get_iresnet_file()
path = os.path.dirname(filename) path = os.path.dirname(filename)
...@@ -180,6 +190,8 @@ class IResnet34(PyTorchModel): ...@@ -180,6 +190,8 @@ class IResnet34(PyTorchModel):
config, config,
memory_demanding=memory_demanding, memory_demanding=memory_demanding,
preprocessor=preprocessor, preprocessor=preprocessor,
device=device,
**kwargs,
) )
def _load_model(self): def _load_model(self):
...@@ -187,6 +199,9 @@ class IResnet34(PyTorchModel): ...@@ -187,6 +199,9 @@ class IResnet34(PyTorchModel):
model = imp.load_source("module", self.config).iresnet34(self.checkpoint_path) model = imp.load_source("module", self.config).iresnet34(self.checkpoint_path)
self.model = model self.model = model
self.model.eval()
self.place_model_on_device()
class IResnet50(PyTorchModel): class IResnet50(PyTorchModel):
""" """
...@@ -194,7 +209,11 @@ class IResnet50(PyTorchModel): ...@@ -194,7 +209,11 @@ class IResnet50(PyTorchModel):
""" """
def __init__( def __init__(
self, preprocessor=lambda x: (x - 127.5) / 128.0, memory_demanding=False self,
preprocessor=lambda x: (x - 127.5) / 128.0,
memory_demanding=False,
device=None,
**kwargs,
): ):
filename = _get_iresnet_file() filename = _get_iresnet_file()
...@@ -208,6 +227,8 @@ class IResnet50(PyTorchModel): ...@@ -208,6 +227,8 @@ class IResnet50(PyTorchModel):
config, config,
memory_demanding=memory_demanding, memory_demanding=memory_demanding,
preprocessor=preprocessor, preprocessor=preprocessor,
device=device,
**kwargs,
) )
def _load_model(self): def _load_model(self):
...@@ -215,6 +236,9 @@ class IResnet50(PyTorchModel): ...@@ -215,6 +236,9 @@ class IResnet50(PyTorchModel):
model = imp.load_source("module", self.config).iresnet50(self.checkpoint_path) model = imp.load_source("module", self.config).iresnet50(self.checkpoint_path)
self.model = model self.model = model
self.model.eval()
self.place_model_on_device()
class IResnet100(PyTorchModel): class IResnet100(PyTorchModel):
""" """
...@@ -222,7 +246,11 @@ class IResnet100(PyTorchModel): ...@@ -222,7 +246,11 @@ class IResnet100(PyTorchModel):
""" """
def __init__( def __init__(
self, preprocessor=lambda x: (x - 127.5) / 128.0, memory_demanding=False self,
preprocessor=lambda x: (x - 127.5) / 128.0,
memory_demanding=False,
device=None,
**kwargs,
): ):
filename = _get_iresnet_file() filename = _get_iresnet_file()
...@@ -236,6 +264,8 @@ class IResnet100(PyTorchModel): ...@@ -236,6 +264,8 @@ class IResnet100(PyTorchModel):
config, config,
memory_demanding=memory_demanding, memory_demanding=memory_demanding,
preprocessor=preprocessor, preprocessor=preprocessor,
device=device,
**kwargs,
) )
def _load_model(self): def _load_model(self):
...@@ -243,6 +273,9 @@ class IResnet100(PyTorchModel): ...@@ -243,6 +273,9 @@ class IResnet100(PyTorchModel):
model = imp.load_source("module", self.config).iresnet100(self.checkpoint_path) model = imp.load_source("module", self.config).iresnet100(self.checkpoint_path)
self.model = model self.model = model
self.model.eval()
self.place_model_on_device()
def iresnet_template(embedding, annotation_type, fixed_positions=None): def iresnet_template(embedding, annotation_type, fixed_positions=None):
# DEFINE CROPPING # DEFINE CROPPING
...@@ -272,7 +305,7 @@ def iresnet_template(embedding, annotation_type, fixed_positions=None): ...@@ -272,7 +305,7 @@ def iresnet_template(embedding, annotation_type, fixed_positions=None):
def iresnet34(annotation_type, fixed_positions=None, memory_demanding=False): def iresnet34(annotation_type, fixed_positions=None, memory_demanding=False):
""" """
Get the Resnet34 pipeline which will crop the face :math:`112 \times 112` and Get the Resnet34 pipeline which will crop the face :math:`112 \times 112` and
use the :py:class:`IResnet34` to extract the features use the :py:class:`IResnet34` to extract the features
...@@ -302,7 +335,7 @@ def iresnet34(annotation_type, fixed_positions=None, memory_demanding=False): ...@@ -302,7 +335,7 @@ def iresnet34(annotation_type, fixed_positions=None, memory_demanding=False):
def iresnet50(annotation_type, fixed_positions=None, memory_demanding=False): def iresnet50(annotation_type, fixed_positions=None, memory_demanding=False):
""" """
Get the Resnet50 pipeline which will crop the face :math:`112 \times 112` and Get the Resnet50 pipeline which will crop the face :math:`112 \times 112` and
use the :py:class:`IResnet50` to extract the features use the :py:class:`IResnet50` to extract the features
...@@ -332,7 +365,7 @@ def iresnet50(annotation_type, fixed_positions=None, memory_demanding=False): ...@@ -332,7 +365,7 @@ def iresnet50(annotation_type, fixed_positions=None, memory_demanding=False):
def iresnet100(annotation_type, fixed_positions=None, memory_demanding=False): def iresnet100(annotation_type, fixed_positions=None, memory_demanding=False):
""" """
Get the Resnet100 pipeline which will crop the face :math:`112 \times 112` and Get the Resnet100 pipeline which will crop the face :math:`112 \times 112` and
use the :py:class:`IResnet100` to extract the features use the :py:class:`IResnet100` to extract the features
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment