From 9100d0db48c2c863b1032b75fac2c5fb07bfc6ae Mon Sep 17 00:00:00 2001
From: Laurent COLBOIS <lcolbois@.idiap.ch>
Date: Thu, 6 May 2021 09:16:10 +0200
Subject: [PATCH] Refactor reading from the database configuration

---
 .../config/baseline/arcface_insightface.py     | 17 +++++------------
 bob/bio/face/config/baseline/dummy.py          | 18 ++++++------------
 .../face/config/baseline/facenet_sanderberg.py | 17 +++++------------
 bob/bio/face/config/baseline/gabor_graph.py    |  9 ++-------
 bob/bio/face/config/baseline/helpers.py        | 18 ++++++++++++++++++
 .../inception_resnetv1_casiawebface.py         | 17 +++++------------
 .../baseline/inception_resnetv1_msceleb.py     | 17 +++++------------
 .../inception_resnetv2_casiawebface.py         | 17 +++++------------
 .../baseline/inception_resnetv2_msceleb.py     | 17 +++++------------
 bob/bio/face/config/baseline/lda.py            |  9 ++-------
 bob/bio/face/config/baseline/lgbphs.py         |  9 ++-------
 .../mobilenetv2_msceleb_arcface_2021.py        | 17 +++++------------
 .../baseline/resnet50_msceleb_arcface_2021.py  | 17 +++++------------
 .../baseline/resnet50_vgg2_arcface_2021.py     | 17 +++++------------
 .../config/baseline/tf2_inception_resnet.py    | 13 ++-----------
 15 files changed, 77 insertions(+), 152 deletions(-)

diff --git a/bob/bio/face/config/baseline/arcface_insightface.py b/bob/bio/face/config/baseline/arcface_insightface.py
index 3aae8949..b2de000f 100644
--- a/bob/bio/face/config/baseline/arcface_insightface.py
+++ b/bob/bio/face/config/baseline/arcface_insightface.py
@@ -1,21 +1,14 @@
 from bob.bio.face.embeddings.mxnet_models import ArcFaceInsightFace
-from bob.bio.face.config.baseline.helpers import embedding_transformer_112x112
+from bob.bio.face.config.baseline.helpers import (
+    embedding_transformer_112x112,
+    lookup_config_from_database,
+)
 from bob.bio.base.pipelines.vanilla_biometrics import (
     Distance,
     VanillaBiometricsPipeline,
 )
 
-
-if "database" in locals():
-    annotation_type = database.annotation_type
-    fixed_positions = database.fixed_positions
-    memory_demanding = (
-        database.memory_demanding if hasattr(database, "memory_demanding") else False
-    )
-else:
-    annotation_type = None
-    fixed_positions = None
-    memory_demanding = False
+annotation_type, fixed_positions, memory_demanding = lookup_config_from_database()
 
 
 def load(annotation_type, fixed_positions=None):
diff --git a/bob/bio/face/config/baseline/dummy.py b/bob/bio/face/config/baseline/dummy.py
index 7e0317a3..eaa993d3 100644
--- a/bob/bio/face/config/baseline/dummy.py
+++ b/bob/bio/face/config/baseline/dummy.py
@@ -5,21 +5,17 @@ from bob.bio.base.pipelines.vanilla_biometrics import (
     VanillaBiometricsPipeline,
 )
 from bob.pipelines.transformers import SampleLinearize
+from bob.bio.face.config.baseline.helpers import lookup_config_from_database
 
-if "database" in locals():
-    annotation_type = database.annotation_type
-    fixed_positions = database.fixed_positions
-else:
-    annotation_type = None
-    fixed_positions = None
-
+annotation_type, fixed_positions, memory_demanding = lookup_config_from_database()
 
 import bob.ip.color
 from sklearn.base import TransformerMixin, BaseEstimator
-class ToGray(TransformerMixin, BaseEstimator):
 
+
+class ToGray(TransformerMixin, BaseEstimator):
     def transform(self, X, annotations=None):
-        return [bob.ip.color.rgb_to_gray(data)[0:10,0:10] for data in X]
+        return [bob.ip.color.rgb_to_gray(data)[0:10, 0:10] for data in X]
 
     def _more_tags(self):
         return {"stateless": True, "requires_fit": False}
@@ -34,9 +30,7 @@ def load(annotation_type, fixed_positions=None):
 
     transformer = make_pipeline(
         wrap(
-            ["sample"],
-            ToGray(),
-            transform_extra_arguments=transform_extra_arguments,
+            ["sample"], ToGray(), transform_extra_arguments=transform_extra_arguments,
         ),
         SampleLinearize(),
     )
diff --git a/bob/bio/face/config/baseline/facenet_sanderberg.py b/bob/bio/face/config/baseline/facenet_sanderberg.py
index e3d1dc26..20e3651a 100644
--- a/bob/bio/face/config/baseline/facenet_sanderberg.py
+++ b/bob/bio/face/config/baseline/facenet_sanderberg.py
@@ -1,23 +1,16 @@
 from bob.bio.face.embeddings.tf2_inception_resnet import (
     FaceNetSanderberg_20170512_110547,
 )
-from bob.bio.face.config.baseline.helpers import embedding_transformer_160x160
+from bob.bio.face.config.baseline.helpers import (
+    embedding_transformer_160x160,
+    lookup_config_from_database,
+)
 from bob.bio.base.pipelines.vanilla_biometrics import (
     Distance,
     VanillaBiometricsPipeline,
 )
 
-memory_demanding = False
-if "database" in locals():
-    annotation_type = database.annotation_type
-    fixed_positions = database.fixed_positions
-    memory_demanding = (
-        database.memory_demanding if hasattr(database, "memory_demanding") else False
-    )
-
-else:
-    annotation_type = None
-    fixed_positions = None
+annotation_type, fixed_positions, memory_demanding = lookup_config_from_database()
 
 
 def load(annotation_type, fixed_positions=None):
diff --git a/bob/bio/face/config/baseline/gabor_graph.py b/bob/bio/face/config/baseline/gabor_graph.py
index 1f1a7860..6c79bffd 100644
--- a/bob/bio/face/config/baseline/gabor_graph.py
+++ b/bob/bio/face/config/baseline/gabor_graph.py
@@ -3,7 +3,7 @@ from bob.bio.base.pipelines.vanilla_biometrics import (
     VanillaBiometricsPipeline,
     BioAlgorithmLegacy,
 )
-from bob.bio.face.config.baseline.helpers import crop_80x64
+from bob.bio.face.config.baseline.helpers import crop_80x64, lookup_config_from_database
 import math
 import numpy as np
 import bob.bio.face
@@ -17,12 +17,7 @@ import logging
 logger = logging.getLogger(__name__)
 
 #### SOLVING IF THERE'S ANY DATABASE INFORMATION
-if "database" in locals():
-    annotation_type = database.annotation_type
-    fixed_positions = database.fixed_positions
-else:
-    annotation_type = None
-    fixed_positions = None
+annotation_type, fixed_positions, memory_demanding = lookup_config_from_database()
 
 
 def get_cropper(annotation_type, fixed_positions=None):
diff --git a/bob/bio/face/config/baseline/helpers.py b/bob/bio/face/config/baseline/helpers.py
index aa9e7cd4..8c04bcc2 100644
--- a/bob/bio/face/config/baseline/helpers.py
+++ b/bob/bio/face/config/baseline/helpers.py
@@ -9,6 +9,24 @@ import logging
 logger = logging.getLogger(__name__)
 
 
+def lookup_config_from_database():
+    if "database" in locals():
+        annotation_type = database.annotation_type
+        fixed_positions = database.fixed_positions
+        memory_demanding = (
+            database.memory_demanding
+            if hasattr(database, "memory_demanding")
+            else False
+        )
+
+    else:
+        annotation_type = None
+        fixed_positions = None
+        memory_demanding = False
+
+    return annotation_type, fixed_positions, memory_demanding
+
+
 def face_crop_solver(
     cropped_image_size,
     cropped_positions=None,
diff --git a/bob/bio/face/config/baseline/inception_resnetv1_casiawebface.py b/bob/bio/face/config/baseline/inception_resnetv1_casiawebface.py
index acc6f981..90bd0111 100644
--- a/bob/bio/face/config/baseline/inception_resnetv1_casiawebface.py
+++ b/bob/bio/face/config/baseline/inception_resnetv1_casiawebface.py
@@ -1,23 +1,16 @@
 from bob.bio.face.embeddings.tf2_inception_resnet import (
     InceptionResnetv1_Casia_CenterLoss_2018,
 )
-from bob.bio.face.config.baseline.helpers import embedding_transformer_160x160
+from bob.bio.face.config.baseline.helpers import (
+    embedding_transformer_160x160,
+    lookup_config_from_database,
+)
 from bob.bio.base.pipelines.vanilla_biometrics import (
     Distance,
     VanillaBiometricsPipeline,
 )
 
-memory_demanding = False
-if "database" in locals():
-    annotation_type = database.annotation_type
-    fixed_positions = database.fixed_positions
-    memory_demanding = (
-        database.memory_demanding if hasattr(database, "memory_demanding") else False
-    )
-
-else:
-    annotation_type = None
-    fixed_positions = None
+annotation_type, fixed_positions, memory_demanding = lookup_config_from_database()
 
 
 def load(annotation_type, fixed_positions=None):
diff --git a/bob/bio/face/config/baseline/inception_resnetv1_msceleb.py b/bob/bio/face/config/baseline/inception_resnetv1_msceleb.py
index 70a1a58f..c4597096 100644
--- a/bob/bio/face/config/baseline/inception_resnetv1_msceleb.py
+++ b/bob/bio/face/config/baseline/inception_resnetv1_msceleb.py
@@ -1,23 +1,16 @@
 from bob.bio.face.embeddings.tf2_inception_resnet import (
     InceptionResnetv1_MsCeleb_CenterLoss_2018,
 )
-from bob.bio.face.config.baseline.helpers import embedding_transformer_160x160
+from bob.bio.face.config.baseline.helpers import (
+    embedding_transformer_160x160,
+    lookup_config_from_database,
+)
 from bob.bio.base.pipelines.vanilla_biometrics import (
     Distance,
     VanillaBiometricsPipeline,
 )
 
-memory_demanding = False
-if "database" in locals():
-    annotation_type = database.annotation_type
-    fixed_positions = database.fixed_positions
-    memory_demanding = (
-        database.memory_demanding if hasattr(database, "memory_demanding") else False
-    )
-
-else:
-    annotation_type = None
-    fixed_positions = None
+annotation_type, fixed_positions, memory_demanding = lookup_config_from_database()
 
 
 def load(annotation_type, fixed_positions=None):
diff --git a/bob/bio/face/config/baseline/inception_resnetv2_casiawebface.py b/bob/bio/face/config/baseline/inception_resnetv2_casiawebface.py
index 3dd27043..cb82dee8 100644
--- a/bob/bio/face/config/baseline/inception_resnetv2_casiawebface.py
+++ b/bob/bio/face/config/baseline/inception_resnetv2_casiawebface.py
@@ -1,23 +1,16 @@
 from bob.bio.face.embeddings.tf2_inception_resnet import (
     InceptionResnetv2_Casia_CenterLoss_2018,
 )
-from bob.bio.face.config.baseline.helpers import embedding_transformer_160x160
+from bob.bio.face.config.baseline.helpers import (
+    embedding_transformer_160x160,
+    lookup_config_from_database,
+)
 from bob.bio.base.pipelines.vanilla_biometrics import (
     Distance,
     VanillaBiometricsPipeline,
 )
 
-memory_demanding = False
-if "database" in locals():
-    annotation_type = database.annotation_type
-    fixed_positions = database.fixed_positions
-    memory_demanding = (
-        database.memory_demanding if hasattr(database, "memory_demanding") else False
-    )
-
-else:
-    annotation_type = None
-    fixed_positions = None
+annotation_type, fixed_positions, memory_demanding = lookup_config_from_database()
 
 
 def load(annotation_type, fixed_positions=None):
diff --git a/bob/bio/face/config/baseline/inception_resnetv2_msceleb.py b/bob/bio/face/config/baseline/inception_resnetv2_msceleb.py
index ba339bcf..fa502308 100644
--- a/bob/bio/face/config/baseline/inception_resnetv2_msceleb.py
+++ b/bob/bio/face/config/baseline/inception_resnetv2_msceleb.py
@@ -1,23 +1,16 @@
 from bob.bio.face.embeddings.tf2_inception_resnet import (
     InceptionResnetv2_MsCeleb_CenterLoss_2018,
 )
-from bob.bio.face.config.baseline.helpers import embedding_transformer_160x160
+from bob.bio.face.config.baseline.helpers import (
+    embedding_transformer_160x160,
+    lookup_config_from_database,
+)
 from bob.bio.base.pipelines.vanilla_biometrics import (
     Distance,
     VanillaBiometricsPipeline,
 )
 
-memory_demanding = False
-if "database" in locals():
-    annotation_type = database.annotation_type
-    fixed_positions = database.fixed_positions
-
-    memory_demanding = (
-        database.memory_demanding if hasattr(database, "memory_demanding") else False
-    )
-else:
-    annotation_type = None
-    fixed_positions = None
+annotation_type, fixed_positions, memory_demanding = lookup_config_from_database()
 
 
 def load(annotation_type, fixed_positions=None):
diff --git a/bob/bio/face/config/baseline/lda.py b/bob/bio/face/config/baseline/lda.py
index 1343f51a..062f18f6 100644
--- a/bob/bio/face/config/baseline/lda.py
+++ b/bob/bio/face/config/baseline/lda.py
@@ -3,7 +3,7 @@ from bob.bio.base.pipelines.vanilla_biometrics import (
     VanillaBiometricsPipeline,
     BioAlgorithmLegacy,
 )
-from bob.bio.face.config.baseline.helpers import crop_80x64
+from bob.bio.face.config.baseline.helpers import crop_80x64, lookup_config_from_database
 import numpy as np
 import bob.bio.face
 from sklearn.pipeline import make_pipeline
@@ -18,12 +18,7 @@ import logging
 logger = logging.getLogger(__name__)
 
 #### SOLVING IF THERE'S ANY DATABASE INFORMATION
-if "database" in locals():
-    annotation_type = database.annotation_type
-    fixed_positions = database.fixed_positions
-else:
-    annotation_type = None
-    fixed_positions = None
+annotation_type, fixed_positions, memory_demanding = lookup_config_from_database()
 
 
 ####### SOLVING THE FACE CROPPER TO BE USED ##########
diff --git a/bob/bio/face/config/baseline/lgbphs.py b/bob/bio/face/config/baseline/lgbphs.py
index eebcbba5..902f43cf 100644
--- a/bob/bio/face/config/baseline/lgbphs.py
+++ b/bob/bio/face/config/baseline/lgbphs.py
@@ -3,7 +3,7 @@ from bob.bio.base.pipelines.vanilla_biometrics import (
     VanillaBiometricsPipeline,
     BioAlgorithmLegacy,
 )
-from bob.bio.face.config.baseline.helpers import crop_80x64
+from bob.bio.face.config.baseline.helpers import crop_80x64, lookup_config_from_database
 import math
 import numpy as np
 import bob.bio.face
@@ -13,12 +13,7 @@ import bob.math
 
 
 #### SOLVING IF THERE'S ANY DATABASE INFORMATION
-if "database" in locals():
-    annotation_type = database.annotation_type
-    fixed_positions = database.fixed_positions
-else:
-    annotation_type = None
-    fixed_positions = None
+annotation_type, fixed_positions, memory_demanding = lookup_config_from_database()
 
 
 def get_cropper(annotation_type, fixed_positions=None):
diff --git a/bob/bio/face/config/baseline/mobilenetv2_msceleb_arcface_2021.py b/bob/bio/face/config/baseline/mobilenetv2_msceleb_arcface_2021.py
index e68e7ef3..aae95a2a 100644
--- a/bob/bio/face/config/baseline/mobilenetv2_msceleb_arcface_2021.py
+++ b/bob/bio/face/config/baseline/mobilenetv2_msceleb_arcface_2021.py
@@ -1,21 +1,14 @@
 from bob.bio.face.embeddings.mobilenet_v2 import MobileNetv2_MsCeleb_ArcFace_2021
-from bob.bio.face.config.baseline.helpers import embedding_transformer_112x112
+from bob.bio.face.config.baseline.helpers import (
+    embedding_transformer_112x112,
+    lookup_config_from_database,
+)
 from bob.bio.base.pipelines.vanilla_biometrics import (
     Distance,
     VanillaBiometricsPipeline,
 )
 
-memory_demanding = False
-if "database" in locals():
-    annotation_type = database.annotation_type
-    fixed_positions = database.fixed_positions
-
-    memory_demanding = (
-        database.memory_demanding if hasattr(database, "memory_demanding") else False
-    )
-else:
-    annotation_type = None
-    fixed_positions = None
+annotation_type, fixed_positions, memory_demanding = lookup_config_from_database()
 
 
 def load(annotation_type, fixed_positions=None):
diff --git a/bob/bio/face/config/baseline/resnet50_msceleb_arcface_2021.py b/bob/bio/face/config/baseline/resnet50_msceleb_arcface_2021.py
index dfeb4b74..f60903ac 100644
--- a/bob/bio/face/config/baseline/resnet50_msceleb_arcface_2021.py
+++ b/bob/bio/face/config/baseline/resnet50_msceleb_arcface_2021.py
@@ -1,21 +1,14 @@
 from bob.bio.face.embeddings.resnet50 import Resnet50_MsCeleb_ArcFace_2021
-from bob.bio.face.config.baseline.helpers import embedding_transformer_112x112
+from bob.bio.face.config.baseline.helpers import (
+    embedding_transformer_112x112,
+    lookup_config_from_database,
+)
 from bob.bio.base.pipelines.vanilla_biometrics import (
     Distance,
     VanillaBiometricsPipeline,
 )
 
-memory_demanding = False
-if "database" in locals():
-    annotation_type = database.annotation_type
-    fixed_positions = database.fixed_positions
-
-    memory_demanding = (
-        database.memory_demanding if hasattr(database, "memory_demanding") else False
-    )
-else:
-    annotation_type = None
-    fixed_positions = None
+annotation_type, fixed_positions, memory_demanding = lookup_config_from_database()
 
 
 def load(annotation_type, fixed_positions=None):
diff --git a/bob/bio/face/config/baseline/resnet50_vgg2_arcface_2021.py b/bob/bio/face/config/baseline/resnet50_vgg2_arcface_2021.py
index b8f13ec4..0a6941d0 100644
--- a/bob/bio/face/config/baseline/resnet50_vgg2_arcface_2021.py
+++ b/bob/bio/face/config/baseline/resnet50_vgg2_arcface_2021.py
@@ -1,21 +1,14 @@
 from bob.bio.face.embeddings.resnet50 import Resnet50_VGG2_ArcFace_2021
-from bob.bio.face.config.baseline.helpers import embedding_transformer_112x112
+from bob.bio.face.config.baseline.helpers import (
+    embedding_transformer_112x112,
+    lookup_config_from_database,
+)
 from bob.bio.base.pipelines.vanilla_biometrics import (
     Distance,
     VanillaBiometricsPipeline,
 )
 
-memory_demanding = False
-if "database" in locals():
-    annotation_type = database.annotation_type
-    fixed_positions = database.fixed_positions
-
-    memory_demanding = (
-        database.memory_demanding if hasattr(database, "memory_demanding") else False
-    )
-else:
-    annotation_type = None
-    fixed_positions = None
+annotation_type, fixed_positions, memory_demanding = lookup_config_from_database()
 
 
 def load(annotation_type, fixed_positions=None):
diff --git a/bob/bio/face/config/baseline/tf2_inception_resnet.py b/bob/bio/face/config/baseline/tf2_inception_resnet.py
index 98048315..f71e0a78 100644
--- a/bob/bio/face/config/baseline/tf2_inception_resnet.py
+++ b/bob/bio/face/config/baseline/tf2_inception_resnet.py
@@ -4,6 +4,7 @@ from bob.bio.face.preprocessor import FaceCrop
 from bob.bio.face.config.baseline.helpers import (
     embedding_transformer_default_cropping,
     embedding_transformer,
+    lookup_config_from_database,
 )
 
 from sklearn.pipeline import make_pipeline
@@ -13,17 +14,7 @@ from bob.bio.base.pipelines.vanilla_biometrics import (
     VanillaBiometricsPipeline,
 )
 
-memory_demanding = False
-if "database" in locals():
-    annotation_type = database.annotation_type
-    fixed_positions = database.fixed_positions
-    memory_demanding = (
-        database.memory_demanding if hasattr(database, "memory_demanding") else False
-    )
-
-else:
-    annotation_type = None
-    fixed_positions = None
+annotation_type, fixed_positions, memory_demanding = lookup_config_from_database()
 
 
 def load(annotation_type, fixed_positions=None):
-- 
GitLab