diff --git a/bob/bio/base/pipelines/vanilla_biometrics/wrappers.py b/bob/bio/base/pipelines/vanilla_biometrics/wrappers.py index 157c0def30ece6b451796f480306906f821526ad..fed15da90ea07dfff467b338e193da49378eeec5 100644 --- a/bob/bio/base/pipelines/vanilla_biometrics/wrappers.py +++ b/bob/bio/base/pipelines/vanilla_biometrics/wrappers.py @@ -38,16 +38,24 @@ class BioAlgorithmCheckpointWrapper(BioAlgorithm): """ - def __init__(self, biometric_algorithm, base_dir, force=False, **kwargs): + def __init__(self, biometric_algorithm, base_dir, group=None, force=False, **kwargs): super().__init__(**kwargs) - self.biometric_reference_dir = os.path.join(base_dir, "biometric_references") - self.score_dir = os.path.join(base_dir, "scores") + self.base_dir = base_dir + self.set_score_references_path(group) + self.biometric_algorithm = biometric_algorithm self.force = force self._biometric_reference_extension = ".hdf5" - self._score_extension = ".pkl" - self.base_dir = base_dir + self._score_extension = ".pkl" + + def set_score_references_path(self, group): + if group is None: + self.biometric_reference_dir = os.path.join(self.base_dir, "biometric_references") + self.score_dir = os.path.join(self.base_dir, "scores") + else: + self.biometric_reference_dir = os.path.join(self.base_dir, group, "biometric_references") + self.score_dir = os.path.join(self.base_dir, group, "scores") def enroll(self, enroll_features): return self.biometric_algorithm.enroll(enroll_features) @@ -190,6 +198,9 @@ class BioAlgorithmDaskWrapper(BioAlgorithm): biometric_references, data ) + def set_score_references_path(self, group): + self.biometric_algorithm.set_score_references_path(group) + def dask_vanilla_biometrics(pipeline, npartitions=None, partition_size=None): """