Skip to content
Snippets Groups Projects
Commit 6cca593f authored by Laurent COLBOIS's avatar Laurent COLBOIS
Browse files

Added TF1/TF2 embedding comparison tests for all Inception networks. Fixed...

Added TF1/TF2 embedding comparison tests for all Inception networks. Fixed inconsistencies in input data scaling.
parent 62986b36
No related branches found
No related tags found
1 merge request!84Test Facenet
Pipeline #46317 passed
......@@ -14,7 +14,7 @@ def test_idiap_inceptionv2_msceleb():
reference = bob.io.base.load(
pkg_resources.resource_filename(
"bob.bio.face.test", "data/inception_resnet_v2_rgb.hdf5"
"bob.bio.face.test", "data/inception_resnet_v2_msceleb_rgb.hdf5"
)
)
np.random.seed(10)
......@@ -34,11 +34,18 @@ def test_idiap_inceptionv2_msceleb():
@is_library_available("tensorflow")
def test_idiap_inceptionv2_casia():
from bob.bio.face.embeddings import InceptionResnetv2_Casia_CenterLoss_2018
from bob.bio.face.embeddings.tf2_inception_resnet import (
InceptionResnetv2_Casia_CenterLoss_2018,
)
reference = bob.io.base.load(
pkg_resources.resource_filename(
"bob.bio.face.test", "data/inception_resnet_v2_casia_rgb.hdf5"
)
)
np.random.seed(10)
transformer = InceptionResnetv2_Casia_CenterLoss_2018()
data = np.random.rand(3, 160, 160).astype("uint8")
data = (np.random.rand(3, 160, 160) * 255).astype("uint8")
output = transformer.transform([data])[0]
assert output.size == 128, output.shape
......@@ -47,6 +54,7 @@ def test_idiap_inceptionv2_casia():
transformer_sample = wrap(["sample"], transformer)
output = [s.data for s in transformer_sample.transform([sample])][0]
np.testing.assert_allclose(output, reference.flatten(), rtol=1e-5, atol=1e-4)
assert output.size == 128, output.shape
......@@ -56,9 +64,14 @@ def test_idiap_inceptionv1_msceleb():
InceptionResnetv1_MsCeleb_CenterLoss_2018,
)
reference = bob.io.base.load(
pkg_resources.resource_filename(
"bob.bio.face.test", "data/inception_resnet_v1_msceleb_rgb.hdf5"
)
)
np.random.seed(10)
transformer = InceptionResnetv1_MsCeleb_CenterLoss_2018()
data = np.random.rand(3, 160, 160).astype("uint8")
data = (np.random.rand(3, 160, 160) * 255).astype("uint8")
output = transformer.transform([data])[0]
assert output.size == 128, output.shape
......@@ -67,6 +80,7 @@ def test_idiap_inceptionv1_msceleb():
transformer_sample = wrap(["sample"], transformer)
output = [s.data for s in transformer_sample.transform([sample])][0]
np.testing.assert_allclose(output, reference.flatten(), rtol=1e-5, atol=1e-4)
assert output.size == 128, output.shape
......@@ -76,9 +90,14 @@ def test_idiap_inceptionv1_casia():
InceptionResnetv1_Casia_CenterLoss_2018,
)
reference = bob.io.base.load(
pkg_resources.resource_filename(
"bob.bio.face.test", "data/inception_resnet_v1_casia_rgb.hdf5"
)
)
np.random.seed(10)
transformer = InceptionResnetv1_Casia_CenterLoss_2018()
data = np.random.rand(3, 160, 160).astype("uint8")
data = (np.random.rand(3, 160, 160) * 255).astype("uint8")
output = transformer.transform([data])[0]
assert output.size == 128, output.shape
......@@ -87,6 +106,7 @@ def test_idiap_inceptionv1_casia():
transformer_sample = wrap(["sample"], transformer)
output = [s.data for s in transformer_sample.transform([sample])][0]
np.testing.assert_allclose(output, reference.flatten(), rtol=1e-5, atol=1e-4)
assert output.size == 128, output.shape
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment