Commit 32589f24 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Implemented the hash_fn mechanism into the biometric algorithm

parent c746d5f0
Pipeline #53151 passed with stage
in 17 minutes and 49 seconds
...@@ -49,6 +49,14 @@ class BioAlgorithmCheckpointWrapper(BioAlgorithm): ...@@ -49,6 +49,14 @@ class BioAlgorithmCheckpointWrapper(BioAlgorithm):
force: bool force: bool
If True, will recompute scores and biometric references no matter if a file exists If True, will recompute scores and biometric references no matter if a file exists
hash_fn
Pointer to a hash function. This hash function maps
`sample.key` to a hash code and this hash code corresponds a relative directory
where a single `sample` will be checkpointed.
This is useful when is desirable file directories with less than
a certain number of files.
Examples Examples
-------- --------
...@@ -59,7 +67,13 @@ class BioAlgorithmCheckpointWrapper(BioAlgorithm): ...@@ -59,7 +67,13 @@ class BioAlgorithmCheckpointWrapper(BioAlgorithm):
""" """
def __init__( def __init__(
self, biometric_algorithm, base_dir, group=None, force=False, **kwargs self,
biometric_algorithm,
base_dir,
group=None,
force=False,
hash_fn=None,
**kwargs
): ):
super().__init__(**kwargs) super().__init__(**kwargs)
...@@ -70,6 +84,7 @@ class BioAlgorithmCheckpointWrapper(BioAlgorithm): ...@@ -70,6 +84,7 @@ class BioAlgorithmCheckpointWrapper(BioAlgorithm):
self.force = force self.force = force
self._biometric_reference_extension = ".hdf5" self._biometric_reference_extension = ".hdf5"
self._score_extension = ".pickle.gz" self._score_extension = ".pickle.gz"
self.hash_fn = hash_fn
def clear_caches(self): def clear_caches(self):
self.biometric_algorithm.clear_caches() self.biometric_algorithm.clear_caches()
...@@ -109,10 +124,16 @@ class BioAlgorithmCheckpointWrapper(BioAlgorithm): ...@@ -109,10 +124,16 @@ class BioAlgorithmCheckpointWrapper(BioAlgorithm):
""" """
# Amending `models` directory # Amending `models` directory
hash_dir_name = (
self.hash_fn(str(sampleset.key)) if self.hash_fn is not None else ""
)
path = os.path.join( path = os.path.join(
self.biometric_reference_dir, self.biometric_reference_dir,
hash_dir_name,
str(sampleset.key) + self._biometric_reference_extension, str(sampleset.key) + self._biometric_reference_extension,
) )
if self.force or not os.path.exists(path): if self.force or not os.path.exists(path):
enrolled_sample = self.biometric_algorithm._enroll_sample_set(sampleset) enrolled_sample = self.biometric_algorithm._enroll_sample_set(sampleset)
...@@ -147,8 +168,14 @@ class BioAlgorithmCheckpointWrapper(BioAlgorithm): ...@@ -147,8 +168,14 @@ class BioAlgorithmCheckpointWrapper(BioAlgorithm):
suffix = "_".join([str(s.key) for s in biometric_references[0:3]]) suffix = "_".join([str(s.key) for s in biometric_references[0:3]])
return os.path.join(reference_id, name + suffix) return os.path.join(reference_id, name + suffix)
# Amending `models` directory
hash_dir_name = (
self.hash_fn(str(sampleset.key)) if self.hash_fn is not None else ""
)
path = os.path.join( path = os.path.join(
self.score_dir, self.score_dir,
hash_dir_name,
_make_name(sampleset, biometric_references) + self._score_extension, _make_name(sampleset, biometric_references) + self._score_extension,
) )
...@@ -333,7 +360,7 @@ def checkpoint_vanilla_biometrics( ...@@ -333,7 +360,7 @@ def checkpoint_vanilla_biometrics(
pipeline.biometric_algorithm.base_dir = bio_ref_scores_dir pipeline.biometric_algorithm.base_dir = bio_ref_scores_dir
else: else:
pipeline.biometric_algorithm = BioAlgorithmCheckpointWrapper( pipeline.biometric_algorithm = BioAlgorithmCheckpointWrapper(
pipeline.biometric_algorithm, base_dir=bio_ref_scores_dir pipeline.biometric_algorithm, base_dir=bio_ref_scores_dir, hash_fn=hash_fn
) )
return pipeline return pipeline
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment