Commit f453965e authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Harmonized interfaces

parent 9ad9020d
......@@ -24,6 +24,9 @@ class MxNetTransformer(TransformerMixin, BaseEstimator):
config : str
json file containing the DNN spec
preprocessor:
A function that will transform the data right before forward. The default transformation is `X=X`
use_gpu: bool
"""
......@@ -33,6 +36,7 @@ class MxNetTransformer(TransformerMixin, BaseEstimator):
config=None,
use_gpu=False,
memory_demanding=False,
preprocessor=lambda x: x,
**kwargs,
):
super().__init__(**kwargs)
......@@ -41,6 +45,7 @@ class MxNetTransformer(TransformerMixin, BaseEstimator):
self.use_gpu = use_gpu
self.model = None
self.memory_demanding = memory_demanding
self.preprocessor = preprocessor
def _load_model(self):
import mxnet as mx
......@@ -65,6 +70,7 @@ class MxNetTransformer(TransformerMixin, BaseEstimator):
self._load_model()
X = check_array(X, allow_nd=True)
X = self.preprocessor(X)
def _transform(X):
X = mx.nd.array(X)
......
......@@ -13,13 +13,35 @@ from bob.extension.download import get_file
class PyTorchModel(TransformerMixin, BaseEstimator):
"""
Base Transformer using pytorch models
Parameters
----------
checkpoint_path: str
Path containing the checkpoint
config:
Path containing some configuration file (e.g. .json, .prototxt)
preprocessor:
A function that will transform the data right before forward. The default transformation is `X/255`
"""
def __init__(self, checkpoint_path=None, config=None, **kwargs):
def __init__(
    self, checkpoint_path=None, config=None, preprocessor=lambda x: x / 255, **kwargs
):
    """Build the transformer.

    Parameters
    ----------
    checkpoint_path: str
        Path containing the checkpoint.
    config:
        Path containing some configuration file (e.g. .json, .prototxt).
    preprocessor:
        Callable applied to the data right before the forward pass.
        Defaults to scaling pixel values with ``X / 255``.
    """
    super().__init__(**kwargs)
    # The actual network is loaded lazily (e.g. by _load_model) on first use.
    self.model = None
    self.checkpoint_path = checkpoint_path
    self.config = config
    self.preprocessor = preprocessor
def transform(self, X):
"""__call__(image) -> feature
......@@ -42,7 +64,7 @@ class PyTorchModel(TransformerMixin, BaseEstimator):
self._load_model()
X = check_array(X, allow_nd=True)
X = torch.Tensor(X)
X = X / 255
X = self.preprocessor(X)
return self.model(X).detach().numpy()
......
......@@ -33,7 +33,7 @@ class TensorflowTransformer(TransformerMixin, BaseEstimator):
Path containing the checkpoint
preprocessor:
Preprocessor function
A function that will transform the data right before forward
memory_demanding bool
If `True`, the `transform` method will run one sample at the time.
......
......@@ -18,11 +18,11 @@ from bob.bio.base.test.utils import is_library_available
# Test fixtures: (image path, hand-annotated eye centers) per role.
# NOTE(review): the diff rendering had left BOTH the old and the new
# annotation dicts inside each tuple (a 3-tuple); only the post-commit
# annotations are kept. Presumably the coordinates are (row, col) after
# the harmonization commit -- confirm against the annotator interface.
images = dict()
images["bioref"] = (
    pkg_resources.resource_filename("bob.bio.face.test", "data/testimage.jpg"),
    {"reye": (176, 131), "leye": (170, 222)},
)
images["probe"] = (
    pkg_resources.resource_filename("bob.bio.face.test", "data/ada.png"),
    {"reye": (207, 440), "leye": (207, 546)},
)
......@@ -81,7 +81,7 @@ def run_baseline(baseline, samples_for_training=[], target_scores=None):
assert len(checkpoint_scores[0]) == 1
if target_scores is not None:
np.allclose(target_scores, scores[0][0].data, atol=10e-3, rtol=10e-3)
assert np.allclose(target_scores, scores[0][0].data, atol=10e-3, rtol=10e-3)
assert np.isclose(scores[0][0].data, checkpoint_scores[0][0].data)
......@@ -114,75 +114,68 @@ def run_baseline(baseline, samples_for_training=[], target_scores=None):
@pytest.mark.slow
@is_library_available("tensorflow")
def test_facenet_baseline():
    # Regression test: the facenet-sanderberg baseline must reproduce this
    # reference score. The stale duplicate call passing the old list-typed
    # target_scores (diff residue) is removed; the harmonized interface
    # takes a scalar.
    run_baseline("facenet-sanderberg", target_scores=-0.9220775737526933)
@pytest.mark.slow
@is_library_available("tensorflow")
def test_inception_resnetv2_msceleb():
    # Regression test against the recorded reference score (scalar form;
    # the stale list-typed duplicate line from the diff residue is removed).
    run_baseline("inception-resnetv2-msceleb", target_scores=-0.43447269718504244)
@pytest.mark.slow
@is_library_available("tensorflow")
def test_inception_resnetv2_casiawebface():
    # Regression test against the recorded reference score (scalar form;
    # the stale list-typed duplicate line from the diff residue is removed).
    run_baseline("inception-resnetv2-casiawebface", target_scores=-0.634583944368043)
@pytest.mark.slow
@is_library_available("tensorflow")
def test_inception_resnetv1_msceleb():
    # Regression test against the recorded reference score (scalar form;
    # the stale list-typed duplicate line from the diff residue is removed).
    run_baseline("inception-resnetv1-msceleb", target_scores=-0.44497649298306907)
@pytest.mark.slow
@is_library_available("tensorflow")
def test_inception_resnetv1_casiawebface():
    # Regression test against the recorded reference score (scalar form;
    # the stale list-typed duplicate line from the diff residue is removed).
    run_baseline("inception-resnetv1-casiawebface", target_scores=-0.6411599976437636)
@pytest.mark.slow
@is_library_available("mxnet")
def test_arcface_insightface():
    # Regression test against the recorded reference score (scalar form;
    # the stale list-typed duplicate line from the diff residue is removed).
    run_baseline("arcface-insightface", target_scores=-0.0005965275677296544)
# NOTE(review): stale pre-commit copy left by the diff rendering -- it
# passes target_scores as a list, while a later definition of
# test_gabor_graph in this file passes a scalar and shadows this one.
def test_gabor_graph():
    run_baseline("gabor_graph", target_scores=[0.4385451147418939])
# def test_lda():
# run_baseline("lda", get_fake_samples_for_training())
@pytest.mark.slow
@is_library_available("tensorflow")
def test_arcface_resnet50_msceleb_v1():
    """ArcFace ResNet50 (MS-Celeb, 2021) baseline must reproduce its reference score."""
    expected = -0.0008105830382632018
    run_baseline("resnet50-msceleb-arcface-2021", target_scores=expected)
@pytest.mark.slow
@is_library_available("opencv-python")
def test_opencv_pipe():
    """Smoke-test the OpenCV pipeline baseline (no reference score to compare)."""
    baseline_name = "opencv-pipe"
    run_baseline(baseline_name, target_scores=None)
@pytest.mark.slow
@is_library_available("tensorflow")
def test_arcface_resnet50_vgg2_v1():
    # Consistency fix: every other heavy-library baseline test in this file
    # carries @pytest.mark.slow; this one was missing the marker.
    # Regression test against the recorded reference score.
    run_baseline("resnet50-vgg2-arcface-2021", target_scores=-0.0035127080413503986)
# @pytest.mark.slow
# @is_library_available("mxnet")
# def test_mxnet_pipe():
# run_baseline("mxnet-pipe", target_scores=None)
def test_gabor_graph():
    """Gabor-graph baseline must reproduce its recorded reference score."""
    expected_score = 0.4385451147418939
    run_baseline("gabor_graph", target_scores=expected_score)
# @pytest.mark.slow
# @is_library_available("tensorflow")
# def test_tf_pipe():
# run_baseline("tf-pipe", target_scores=None)
# def test_lda():
# run_baseline("lda", get_fake_samples_for_training())
@pytest.mark.slow
@is_library_available("torch")
def test_afffe():
    # Regression test: the AFFFE baseline must reproduce this reference
    # score. The stale call with the pre-commit score (-0.7397...) left by
    # the diff residue is removed; only the updated reference is kept.
    run_baseline("afffe", target_scores=-1.0274936425058916)
@pytest.mark.slow
@is_library_available("cv2")
def test_vgg16_oxford():
    # Regression test: VGG16 (Oxford) baseline must reproduce this reference
    # score. Leftover debugging (`import ipdb; ipdb.set_trace()`) and the
    # stale score-less call from the diff residue are removed -- a breakpoint
    # in a test hangs any non-interactive CI run.
    run_baseline("vgg16-oxford", target_scores=-0.9911880900309596)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment