Commit 32589f24 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Implemented the hash_fn mechanism into the biometric algorithm

parent c746d5f0
Pipeline #53151 passed with stage
in 17 minutes and 49 seconds
......@@ -49,6 +49,14 @@ class BioAlgorithmCheckpointWrapper(BioAlgorithm):
force: bool
If True, will recompute scores and biometric references no matter if a file exists
hash_fn
Pointer to a hash function. This hash function maps
`sample.key` to a hash code, and this hash code corresponds to a relative directory
where a single `sample` will be checkpointed.
This is useful when it is desirable to keep directories with fewer than
a certain number of files.
Examples
--------
......@@ -59,7 +67,13 @@ class BioAlgorithmCheckpointWrapper(BioAlgorithm):
"""
def __init__(
self, biometric_algorithm, base_dir, group=None, force=False, **kwargs
self,
biometric_algorithm,
base_dir,
group=None,
force=False,
hash_fn=None,
**kwargs
):
super().__init__(**kwargs)
......@@ -70,6 +84,7 @@ class BioAlgorithmCheckpointWrapper(BioAlgorithm):
self.force = force
self._biometric_reference_extension = ".hdf5"
self._score_extension = ".pickle.gz"
self.hash_fn = hash_fn
def clear_caches(self):
self.biometric_algorithm.clear_caches()
......@@ -109,10 +124,16 @@ class BioAlgorithmCheckpointWrapper(BioAlgorithm):
"""
# Amending `models` directory
hash_dir_name = (
self.hash_fn(str(sampleset.key)) if self.hash_fn is not None else ""
)
path = os.path.join(
self.biometric_reference_dir,
hash_dir_name,
str(sampleset.key) + self._biometric_reference_extension,
)
if self.force or not os.path.exists(path):
enrolled_sample = self.biometric_algorithm._enroll_sample_set(sampleset)
......@@ -147,8 +168,14 @@ class BioAlgorithmCheckpointWrapper(BioAlgorithm):
suffix = "_".join([str(s.key) for s in biometric_references[0:3]])
return os.path.join(reference_id, name + suffix)
# Amending the score directory
hash_dir_name = (
self.hash_fn(str(sampleset.key)) if self.hash_fn is not None else ""
)
path = os.path.join(
self.score_dir,
hash_dir_name,
_make_name(sampleset, biometric_references) + self._score_extension,
)
......@@ -333,7 +360,7 @@ def checkpoint_vanilla_biometrics(
pipeline.biometric_algorithm.base_dir = bio_ref_scores_dir
else:
pipeline.biometric_algorithm = BioAlgorithmCheckpointWrapper(
pipeline.biometric_algorithm, base_dir=bio_ref_scores_dir
pipeline.biometric_algorithm, base_dir=bio_ref_scores_dir, hash_fn=hash_fn
)
return pipeline
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment