From e03b7e83ff5e18a837c2dc1cec53d196443cbb68 Mon Sep 17 00:00:00 2001
From: Tiago Freitas Pereira <tiagofrepereira@gmail.com>
Date: Wed, 27 May 2020 13:41:54 +0200
Subject: [PATCH] Segmenting the scores and biometric references checkpoints by
 group to avoid too many files in one directory

---
 .../pipelines/vanilla_biometrics/wrappers.py  | 21 ++++++++++++++-----
 1 file changed, 16 insertions(+), 5 deletions(-)

diff --git a/bob/bio/base/pipelines/vanilla_biometrics/wrappers.py b/bob/bio/base/pipelines/vanilla_biometrics/wrappers.py
index 157c0def..fed15da9 100644
--- a/bob/bio/base/pipelines/vanilla_biometrics/wrappers.py
+++ b/bob/bio/base/pipelines/vanilla_biometrics/wrappers.py
@@ -38,16 +38,24 @@ class BioAlgorithmCheckpointWrapper(BioAlgorithm):
 
     """
 
-    def __init__(self, biometric_algorithm, base_dir, force=False, **kwargs):
+    def __init__(self, biometric_algorithm, base_dir, group=None, force=False, **kwargs):
         super().__init__(**kwargs)
 
-        self.biometric_reference_dir = os.path.join(base_dir, "biometric_references")
-        self.score_dir = os.path.join(base_dir, "scores")
+        self.base_dir = base_dir
+        self.set_score_references_path(group)
+
         self.biometric_algorithm = biometric_algorithm
         self.force = force
         self._biometric_reference_extension = ".hdf5"
-        self._score_extension = ".pkl"
-        self.base_dir = base_dir
+        self._score_extension = ".pkl"        
+
+    def set_score_references_path(self, group):
+        if group is None:
+            self.biometric_reference_dir = os.path.join(self.base_dir, "biometric_references")
+            self.score_dir = os.path.join(self.base_dir, "scores")
+        else:
+            self.biometric_reference_dir = os.path.join(self.base_dir, group, "biometric_references")
+            self.score_dir = os.path.join(self.base_dir, group, "scores")
 
     def enroll(self, enroll_features):
         return self.biometric_algorithm.enroll(enroll_features)
@@ -190,6 +198,9 @@ class BioAlgorithmDaskWrapper(BioAlgorithm):
             biometric_references, data
         )
 
+    def set_score_references_path(self, group):
+        self.biometric_algorithm.set_score_references_path(group)
+
 
 def dask_vanilla_biometrics(pipeline, npartitions=None, partition_size=None):
     """
-- 
GitLab