Skip to content
Snippets Groups Projects
Verified Commit 91dfec12 authored by Yannick DAYER's avatar Yannick DAYER
Browse files

fix: stop T-norm fit from exploding in memory.

Removed the unnecessary copy of `data` and storage into a Sample object.
Now saves the stats in a dict that is picklable.
parent be04cd23
Branches
Tags
1 merge request!326fix: stop T-norm fit from exploding in memory.
Pipeline #75792 failed
......@@ -296,6 +296,11 @@ class TNormScores(TransformerMixin, BaseEstimator):
Parameters
----------
Attributes
----------
t_stats: a dictionary keeping the ``mu`` and ``std`` of each identity.
"""
post_process_template = "enroll"
......@@ -309,23 +314,12 @@ class TNormScores(TransformerMixin, BaseEstimator):
super().__init__(**kwargs)
self.top_norm = top_norm
self.top_norm_score_fraction = top_norm_score_fraction
self.t_stats: dict[str, dict[str, float]] = dict()
def fit(self, t_scores, y=None):
# TODO: THIS IS SUPER INNEFICIENT, BUT
# IT'S THE MOST READABLE SOLUTION
# Stacking scores by biometric reference
self.t_stats = dict()
"""Computes the mean and std of each SampleSet/identity"""
for sset in t_scores:
self.t_stats[sset.template_id] = Sample(
[s.data for s in sset], parent=sset
)
# Now computing the statistics in place
for key in self.t_stats:
data = self.t_stats[key].data
# Selecting the top scores
data = [s.data for s in sset]
if self.top_norm:
# Sorting in ascending order
data = -np.sort(-data)
......@@ -333,13 +327,10 @@ class TNormScores(TransformerMixin, BaseEstimator):
np.floor(len(data) * self.top_norm_score_fraction)
)
data = data[0:proportion]
self.t_stats[key].mu = np.mean(self.t_stats[key].data)
self.t_stats[key].std = np.std(self.t_stats[key].data)
# self._z_stats[key].std = legacy_std(
# self._z_stats[key].mu, self._z_stats[key].data
# )
self.t_stats[key].data = []
self.t_stats[sset.template_id] = {
"mu": np.mean(data),
"std": np.std(data),
}
return self
......@@ -351,7 +342,7 @@ class TNormScores(TransformerMixin, BaseEstimator):
def _transform_samples(X, stats):
scores = []
for no_normed_score in X:
score = (no_normed_score.data - stats.mu) / stats.std
score = (no_normed_score.data - stats["mu"]) / stats["std"]
t_score = Sample(score, parent=no_normed_score)
scores.append(t_score)
......@@ -372,7 +363,7 @@ class TNormScores(TransformerMixin, BaseEstimator):
)
else:
# If it is Samples
t_normed_scores = _transform_samples(X)
t_normed_scores = _transform_samples(X) # YD2023: Is this needed?
return t_normed_scores
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment