Commit b4cd8c9b authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Created a mechanism that clears biometric reference caches

parent 54b3600e
Pipeline #45953 passed with stage
in 5 minutes and 56 seconds
......@@ -35,6 +35,12 @@ class BioAlgorithm(metaclass=ABCMeta):
self.stacked_biometric_references = None
self.score_reduction_operation = average_scores
def clear_caches(self):
"""
Clean all cached objects from BioAlgorithm
"""
self.stacked_biometric_references = None
def enroll_samples(self, biometric_references):
"""This method should implement the enrollment sub-pipeline of the Vanilla Biometrics Pipeline. TODO REF
......@@ -122,7 +128,7 @@ class BioAlgorithm(metaclass=ABCMeta):
allow_scoring_with_all_biometric_references=allow_scoring_with_all_biometric_references,
)
)
self.clear_caches()
return retval
def _score_sample_set(
......
......@@ -65,6 +65,9 @@ class BioAlgorithmCheckpointWrapper(BioAlgorithm):
self._biometric_reference_extension = ".hdf5"
self._score_extension = ".joblib"
def clear_caches(self):
self.biometric_algorithm.clear_caches()
def set_score_references_path(self, group):
if group is None:
self.biometric_reference_dir = os.path.join(
......@@ -188,6 +191,9 @@ class BioAlgorithmDaskWrapper(BioAlgorithm):
def __init__(self, biometric_algorithm, **kwargs):
self.biometric_algorithm = biometric_algorithm
def clear_caches(self):
self.biometric_algorithm.clear_caches()
def enroll_samples(self, biometric_reference_features):
biometric_references = biometric_reference_features.map_partitions(
......
......@@ -127,6 +127,7 @@ class ZTNormPipeline(object):
biometric_references,
allow_scoring_with_all_biometric_references,
)
if self.t_norm:
if t_biometric_reference_samples is None:
raise ValueError("No samples for `t_norm` was provided")
......@@ -141,6 +142,7 @@ class ZTNormPipeline(object):
raw_scores,
allow_scoring_with_all_biometric_references,
)
if not self.z_norm:
# In case z_norm=False and t_norm=True
return t_normed_scores
......
......@@ -127,11 +127,16 @@ def zt_norm_stubs(references, probes, t_references, z_probes):
zt_normed_scores = _norm(z_normed_scores, z_t_scores, axis=0)
assert zt_normed_scores.shape == (n_reference, n_probes)
s_normed_scores = (z_normed_scores+t_normed_scores)*0.5
s_normed_scores = (z_normed_scores + t_normed_scores) * 0.5
assert s_normed_scores.shape == (n_reference, n_probes)
return raw_scores, z_normed_scores, t_normed_scores, zt_normed_scores, s_normed_scores
return (
raw_scores,
z_normed_scores,
t_normed_scores,
zt_normed_scores,
s_normed_scores,
)
def test_norm_mechanics():
......@@ -285,7 +290,6 @@ def test_norm_mechanics():
)
assert np.allclose(z_normed_scores, z_normed_scores_ref)
############
# TESTING T-NORM
#############
......@@ -319,7 +323,7 @@ def test_norm_mechanics():
t_normed_scores = _dump_scores_from_samples(
t_normed_score_samples, shape=(n_probes, n_references)
)
)
assert np.allclose(t_normed_scores, t_normed_scores_ref)
############
......@@ -371,7 +375,6 @@ def test_norm_mechanics():
scheduler="single-threaded"
)
raw_scores = _dump_scores_from_samples(
raw_score_samples, shape=(n_probes, n_references)
)
......@@ -397,9 +400,6 @@ def test_norm_mechanics():
)
assert np.allclose(s_normed_scores, s_normed_scores_ref)
# No dask
run(False) # On memory
......@@ -438,7 +438,13 @@ def test_znorm_on_memory():
vanilla_biometrics_pipeline, npartitions=2
)
raw_scores, z_scores, t_scores, zt_scores, s_scores = vanilla_biometrics_pipeline(
(
raw_scores,
z_scores,
t_scores,
zt_scores,
s_scores,
) = vanilla_biometrics_pipeline(
database.background_model_samples(),
database.references(),
database.probes(),
......@@ -447,10 +453,6 @@ def test_znorm_on_memory():
allow_scoring_with_all_biometric_references=database.allow_scoring_with_all_biometric_references,
)
# if vanilla_biometrics_pipeline.score_writer is not None:
# concatenated_scores
# pass
def _concatenate(pipeline, scores, path):
writed_scores = pipeline.write_scores(scores)
concatenated_scores = pipeline.post_process(
......@@ -484,7 +486,6 @@ def test_znorm_on_memory():
zt_scores = zt_scores.compute(scheduler="single-threaded")
s_scores = s_scores.compute(scheduler="single-threaded")
if isinstance(score_writer, CSVScoreWriter):
n_lines = 51 if with_dask else 101
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment