From e03b7e83ff5e18a837c2dc1cec53d196443cbb68 Mon Sep 17 00:00:00 2001 From: Tiago Freitas Pereira <tiagofrepereira@gmail.com> Date: Wed, 27 May 2020 13:41:54 +0200 Subject: [PATCH] Segmenting the scores and biometric references checkpoints by group to avoid too many files in one directory --- .../pipelines/vanilla_biometrics/wrappers.py | 21 ++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/bob/bio/base/pipelines/vanilla_biometrics/wrappers.py b/bob/bio/base/pipelines/vanilla_biometrics/wrappers.py index 157c0def..fed15da9 100644 --- a/bob/bio/base/pipelines/vanilla_biometrics/wrappers.py +++ b/bob/bio/base/pipelines/vanilla_biometrics/wrappers.py @@ -38,16 +38,24 @@ class BioAlgorithmCheckpointWrapper(BioAlgorithm): """ - def __init__(self, biometric_algorithm, base_dir, force=False, **kwargs): + def __init__(self, biometric_algorithm, base_dir, group=None, force=False, **kwargs): super().__init__(**kwargs) - self.biometric_reference_dir = os.path.join(base_dir, "biometric_references") - self.score_dir = os.path.join(base_dir, "scores") + self.base_dir = base_dir + self.set_score_references_path(group) + self.biometric_algorithm = biometric_algorithm self.force = force self._biometric_reference_extension = ".hdf5" - self._score_extension = ".pkl" - self.base_dir = base_dir + self._score_extension = ".pkl" + + def set_score_references_path(self, group): + if group is None: + self.biometric_reference_dir = os.path.join(self.base_dir, "biometric_references") + self.score_dir = os.path.join(self.base_dir, "scores") + else: + self.biometric_reference_dir = os.path.join(self.base_dir, group, "biometric_references") + self.score_dir = os.path.join(self.base_dir, group, "scores") def enroll(self, enroll_features): return self.biometric_algorithm.enroll(enroll_features) @@ -190,6 +198,9 @@ class BioAlgorithmDaskWrapper(BioAlgorithm): biometric_references, data ) + def set_score_references_path(self, group): + self.biometric_algorithm.set_score_references_path(group) + def dask_vanilla_biometrics(pipeline, npartitions=None, partition_size=None): """ -- GitLab