From 6cca593f8c2750dd0dc112a562ad9f14c7ab8b77 Mon Sep 17 00:00:00 2001
From: Laurent COLBOIS <lcolbois@.idiap.ch>
Date: Fri, 4 Dec 2020 15:20:10 +0100
Subject: [PATCH] Added TF1/TF2 embedding comparison tests for all Inception
 networks. Fixed inconsistencies in input data scaling.

---
 bob/bio/face/test/test_embeddings.py | 30 +++++++++++++++++++++++-----
 1 file changed, 25 insertions(+), 5 deletions(-)

diff --git a/bob/bio/face/test/test_embeddings.py b/bob/bio/face/test/test_embeddings.py
index 4b0f149d..d4228bef 100644
--- a/bob/bio/face/test/test_embeddings.py
+++ b/bob/bio/face/test/test_embeddings.py
@@ -14,7 +14,7 @@ def test_idiap_inceptionv2_msceleb():
 
     reference = bob.io.base.load(
         pkg_resources.resource_filename(
-            "bob.bio.face.test", "data/inception_resnet_v2_rgb.hdf5"
+            "bob.bio.face.test", "data/inception_resnet_v2_msceleb_rgb.hdf5"
         )
     )
     np.random.seed(10)
@@ -34,11 +34,18 @@ def test_idiap_inceptionv2_msceleb():
 
 @is_library_available("tensorflow")
 def test_idiap_inceptionv2_casia():
-    from bob.bio.face.embeddings import InceptionResnetv2_Casia_CenterLoss_2018
+    from bob.bio.face.embeddings.tf2_inception_resnet import (
+        InceptionResnetv2_Casia_CenterLoss_2018,
+    )
 
+    reference = bob.io.base.load(
+        pkg_resources.resource_filename(
+            "bob.bio.face.test", "data/inception_resnet_v2_casia_rgb.hdf5"
+        )
+    )
     np.random.seed(10)
     transformer = InceptionResnetv2_Casia_CenterLoss_2018()
-    data = np.random.rand(3, 160, 160).astype("uint8")
+    data = (np.random.rand(3, 160, 160) * 255).astype("uint8")
     output = transformer.transform([data])[0]
     assert output.size == 128, output.shape
 
@@ -47,6 +54,7 @@ def test_idiap_inceptionv2_casia():
     transformer_sample = wrap(["sample"], transformer)
     output = [s.data for s in transformer_sample.transform([sample])][0]
 
+    np.testing.assert_allclose(output, reference.flatten(), rtol=1e-5, atol=1e-4)
     assert output.size == 128, output.shape
 
 
@@ -56,9 +64,14 @@ def test_idiap_inceptionv1_msceleb():
         InceptionResnetv1_MsCeleb_CenterLoss_2018,
     )
 
+    reference = bob.io.base.load(
+        pkg_resources.resource_filename(
+            "bob.bio.face.test", "data/inception_resnet_v1_msceleb_rgb.hdf5"
+        )
+    )
     np.random.seed(10)
     transformer = InceptionResnetv1_MsCeleb_CenterLoss_2018()
-    data = np.random.rand(3, 160, 160).astype("uint8")
+    data = (np.random.rand(3, 160, 160) * 255).astype("uint8")
     output = transformer.transform([data])[0]
     assert output.size == 128, output.shape
 
@@ -67,6 +80,7 @@ def test_idiap_inceptionv1_msceleb():
     transformer_sample = wrap(["sample"], transformer)
     output = [s.data for s in transformer_sample.transform([sample])][0]
 
+    np.testing.assert_allclose(output, reference.flatten(), rtol=1e-5, atol=1e-4)
     assert output.size == 128, output.shape
 
 
@@ -76,9 +90,14 @@ def test_idiap_inceptionv1_casia():
         InceptionResnetv1_Casia_CenterLoss_2018,
     )
 
+    reference = bob.io.base.load(
+        pkg_resources.resource_filename(
+            "bob.bio.face.test", "data/inception_resnet_v1_casia_rgb.hdf5"
+        )
+    )
     np.random.seed(10)
     transformer = InceptionResnetv1_Casia_CenterLoss_2018()
-    data = np.random.rand(3, 160, 160).astype("uint8")
+    data = (np.random.rand(3, 160, 160) * 255).astype("uint8")
     output = transformer.transform([data])[0]
     assert output.size == 128, output.shape
 
@@ -87,6 +106,7 @@ def test_idiap_inceptionv1_casia():
     transformer_sample = wrap(["sample"], transformer)
     output = [s.data for s in transformer_sample.transform([sample])][0]
 
+    np.testing.assert_allclose(output, reference.flatten(), rtol=1e-5, atol=1e-4)
     assert output.size == 128, output.shape
 
 
-- 
GitLab