From ed1878961a2698f1cb796410aa93fb93cd0e5306 Mon Sep 17 00:00:00 2001
From: Tiago Freitas Pereira <tiagofrepereira@gmail.com>
Date: Fri, 8 May 2020 17:51:38 +0200
Subject: [PATCH] Got the parallelization of the ZT-Norm computation right

---
 .../pipelines/vanilla_biometrics/__init__.py  |   4 +-
 .../pipelines/vanilla_biometrics/pipelines.py | 227 -----------
 .../pipelines/vanilla_biometrics/wrappers.py  | 131 ++----
 .../pipelines/vanilla_biometrics/zt_norm.py   | 384 ++++++++++++++++++
 .../test_vanilla_biometrics_score_norm.py     | 288 +++++++------
 5 files changed, 573 insertions(+), 461 deletions(-)
 create mode 100644 bob/bio/base/pipelines/vanilla_biometrics/zt_norm.py

diff --git a/bob/bio/base/pipelines/vanilla_biometrics/__init__.py b/bob/bio/base/pipelines/vanilla_biometrics/__init__.py
index dcd7b9e9..e87c415d 100644
--- a/bob/bio/base/pipelines/vanilla_biometrics/__init__.py
+++ b/bob/bio/base/pipelines/vanilla_biometrics/__init__.py
@@ -1,7 +1,9 @@
-from .pipelines import VanillaBiometricsPipeline, ZTNormVanillaBiometricsPipeline
+from .pipelines import VanillaBiometricsPipeline
 
 from .biometric_algorithms import Distance
 from .score_writers import FourColumnsScoreWriter, CSVScoreWriter
 from .wrappers import BioAlgorithmCheckpointWrapper, BioAlgorithmDaskWrapper, dask_vanilla_biometrics
 
+from .zt_norm import ZTNormPipeline, ZTNormDaskWrapper
+
 from .legacy import BioAlgorithmLegacy, DatabaseConnector
\ No newline at end of file
diff --git a/bob/bio/base/pipelines/vanilla_biometrics/pipelines.py b/bob/bio/base/pipelines/vanilla_biometrics/pipelines.py
index 0a1387ed..6dd409fb 100644
--- a/bob/bio/base/pipelines/vanilla_biometrics/pipelines.py
+++ b/bob/bio/base/pipelines/vanilla_biometrics/pipelines.py
@@ -11,8 +11,6 @@ for bob.bio experiments
 import logging
 import numpy
 from .score_writers import FourColumnsScoreWriter
-from .wrappers import BioAlgorithmZTNormWrapper
-
 
 logger = logging.getLogger(__name__)
 
@@ -166,228 +164,3 @@ class VanillaBiometricsPipeline(object):
 
     def write_scores(self, scores):
         return self.score_writer.write(scores)
-
-
-class ZTNormVanillaBiometricsPipeline(object):
-    """
-    Apply Z, T or ZT Score normalization on top of VanillaBiometric Pipeline
-
-    Reference bibliography from: A Generative Model for Score Normalization in Speaker Recognition
-    https://arxiv.org/pdf/1709.09868.pdf
-
-
-    Example
-    -------
-       >>> transformer = make_pipeline([])
-       >>> biometric_algorithm = Distance()
-       >>> vanilla_biometrics_pipeline = VanillaBiometricsPipeline(transformer, biometric_algorithm)
-       >>> zt_pipeline = ZTNormVanillaBiometricsPipeline(vanilla_biometrics_pipeline)
-       >>> zt_pipeline(...)
-
-    Parameters
-    ----------
-
-        vanilla_biometrics_pipeline: :any:`VanillaBiometricsPipeline`
-          An instance :any:`VanillaBiometricsPipeline` to the wrapped with score normalization
-
-        z_norm: bool
-          If True, applies ZScore normalization on top of raw scores.
-
-        t_norm: bool
-          If True, applies TScore normalization on top of raw scores.
-          If both, z_norm and t_norm are true, it applies score normalization
-
-    """
-
-
-    def __init__(self, vanilla_biometrics_pipeline, z_norm=True, t_norm=True):
-        self.vanilla_biometrics_pipeline = vanilla_biometrics_pipeline
-        # Wrapping with ZTNorm
-        self.vanilla_biometrics_pipeline.biometric_algorithm = BioAlgorithmZTNormWrapper(
-            self.vanilla_biometrics_pipeline.biometric_algorithm
-        )
-        self.z_norm = z_norm
-        self.t_norm = t_norm
-
-        if not z_norm and not t_norm:
-            raise ValueError("Both z_norm and t_norm are False. No normalization will be applied")
-
-    def __call__(
-        self,
-        background_model_samples,
-        biometric_reference_samples,
-        probe_samples,
-        zprobe_samples=None,
-        t_biometric_reference_samples=None,
-        allow_scoring_with_all_biometric_references=False,
-    ):
-
-        self.transformer = self.train_background_model(background_model_samples)
-
-        # Create biometric samples
-        biometric_references = self.create_biometric_reference(
-            biometric_reference_samples
-        )
-
-        raw_scores, probe_features = self.compute_scores(
-            probe_samples,
-            biometric_references,
-            allow_scoring_with_all_biometric_references,
-        )
-
-        # Z NORM
-        if self.z_norm:
-            if zprobe_samples is None:
-                raise ValueError("No samples for `z_norm` was provided")
-
-
-            z_normed_scores, z_probe_features = self.compute_znorm_scores(
-                zprobe_samples,
-                raw_scores,
-                biometric_references,
-                allow_scoring_with_all_biometric_references,
-            )
-        if self.t_norm:
-            if t_biometric_reference_samples is None:
-                raise ValueError("No samples for `t_norm` was provided")
-        else:
-            # In case z_norm=True and t_norm=False
-            return z_normed_scores
-
-        # T NORM
-        t_normed_scores, t_scores, t_biometric_references = self.compute_tnorm_scores(
-            t_biometric_reference_samples,
-            probe_features,
-            raw_scores,
-            allow_scoring_with_all_biometric_references,
-        )
-        if not self.z_norm:
-            # In case z_norm=False and t_norm=True
-            return t_normed_scores
-
-
-        # ZT NORM
-        zt_normed_scores = self.compute_ztnorm_scores(
-            z_probe_features,
-            t_biometric_references,
-            z_normed_scores,
-            t_scores,
-            allow_scoring_with_all_biometric_references,
-        )
-
-        return zt_normed_scores
-
-    def train_background_model(self, background_model_samples):
-        return self.vanilla_biometrics_pipeline.train_background_model(
-            background_model_samples
-        )
-
-    def create_biometric_reference(self, biometric_reference_samples):
-        return self.vanilla_biometrics_pipeline.create_biometric_reference(
-            biometric_reference_samples
-        )
-
-    def compute_scores(
-        self,
-        probe_samples,
-        biometric_references,
-        allow_scoring_with_all_biometric_references=False,
-    ):
-
-        return self.vanilla_biometrics_pipeline.compute_scores(
-            probe_samples,
-            biometric_references,
-            allow_scoring_with_all_biometric_references,
-        )
-
-    def _inject_references(self, probe_samples, biometric_references):
-        """
-        Inject references in the current sampleset,
-        so it can run the scores
-        """
-
-        ########## WARNING #######
-        #### I'M MUTATING OBJECTS HERE. THIS CAN GO WRONG
-
-        references = [s.subject  for s in biometric_references]
-        for probe in probe_samples:
-            probe.references = references
-        return probe_samples
-
-
-    def compute_znorm_scores(
-        self,
-        zprobe_samples,
-        probe_scores,
-        biometric_references,
-        allow_scoring_with_all_biometric_references=False,
-    ):
-
-        zprobe_samples = self._inject_references(zprobe_samples, biometric_references)
-
-        z_scores, z_probe_features = self.compute_scores(
-            zprobe_samples, biometric_references
-        )
-
-        z_normed_scores = self.vanilla_biometrics_pipeline.biometric_algorithm.compute_znorm_scores(
-            z_scores, probe_scores, allow_scoring_with_all_biometric_references,
-        )
-
-        return z_normed_scores, z_probe_features
-
-    def compute_tnorm_scores(
-        self,
-        t_biometric_reference_samples,
-        probe_features,
-        probe_scores,
-        allow_scoring_with_all_biometric_references=False,
-    ):
-
-        t_biometric_references = self.create_biometric_reference(
-            t_biometric_reference_samples
-        )
-
-        probe_features = self._inject_references(probe_features, t_biometric_references)
-
-        # Reusing the probe features
-        t_scores = self.vanilla_biometrics_pipeline.biometric_algorithm.score_samples(
-            probe_features,
-            t_biometric_references,
-            allow_scoring_with_all_biometric_references=allow_scoring_with_all_biometric_references,
-        )
-
-        t_normed_scores = self.vanilla_biometrics_pipeline.biometric_algorithm.compute_tnorm_scores(
-            t_scores, probe_scores, allow_scoring_with_all_biometric_references,
-        )
-
-        return t_normed_scores, t_scores, t_biometric_references
-
-    def compute_ztnorm_scores(self,
-            z_probe_features,
-            t_biometric_references,
-            z_normed_scores,
-            t_scores,
-            allow_scoring_with_all_biometric_references=False
-            ):
-
-        z_probe_features = self._inject_references(z_probe_features, t_biometric_references)
-
-        # Reusing the zprobe_features and t_biometric_references
-        zt_scores = self.vanilla_biometrics_pipeline.biometric_algorithm.score_samples(
-            z_probe_features,
-            t_biometric_references,
-            allow_scoring_with_all_biometric_references=allow_scoring_with_all_biometric_references,
-        )
-
-        # Z Normalizing the T-normed scores
-        z_normed_t_normed = self.vanilla_biometrics_pipeline.biometric_algorithm.compute_znorm_scores(
-            zt_scores, t_scores, allow_scoring_with_all_biometric_references,
-        )
-
-        # (Z Normalizing the T-normed scores) the Z normed scores
-        zt_normed_scores = self.vanilla_biometrics_pipeline.biometric_algorithm.compute_tnorm_scores(
-            z_normed_t_normed, z_normed_scores, allow_scoring_with_all_biometric_references,
-        )
-
-
-        return zt_normed_scores
diff --git a/bob/bio/base/pipelines/vanilla_biometrics/wrappers.py b/bob/bio/base/pipelines/vanilla_biometrics/wrappers.py
index 42c9b23b..98434e26 100644
--- a/bob/bio/base/pipelines/vanilla_biometrics/wrappers.py
+++ b/bob/bio/base/pipelines/vanilla_biometrics/wrappers.py
@@ -8,6 +8,7 @@ from .abstract_classes import BioAlgorithm
 import pickle
 import bob.pipelines as mario
 import numpy as np
+from .zt_norm import ZTNormPipeline, ZTNormDaskWrapper
 
 
 class BioAlgorithmCheckpointWrapper(BioAlgorithm):
@@ -121,6 +122,10 @@ class BioAlgorithmCheckpointWrapper(BioAlgorithm):
 
 
 class BioAlgorithmDaskWrapper(BioAlgorithm):
+    """
+    Wrap :any:`BioAlgorithm` to work with DASK
+    """
+
     def __init__(self, biometric_algorithm, **kwargs):
         self.biometric_algorithm = biometric_algorithm
 
@@ -166,7 +171,7 @@ class BioAlgorithmDaskWrapper(BioAlgorithm):
         )
 
 
-def dask_vanilla_biometrics(vanila_biometrics_pipeline, npartitions=None):
+def dask_vanilla_biometrics(pipeline, npartitions=None):
     """
     Given a :any:`VanillaBiometrics`, wraps :any:`VanillaBiometrics.transformer` and
     :any:`VanillaBiometrics.biometric_algorithm` to be executed with dask
@@ -174,123 +179,35 @@ def dask_vanilla_biometrics(vanila_biometrics_pipeline, npartitions=None):
     Parameters
     ----------
 
-    vanila_biometrics_pipeline: :any:`VanillaBiometrics`
+    pipeline: :any:`VanillaBiometricsPipeline` or :any:`ZTNormPipeline`
        Vanilla Biometrics based pipeline to be dasked
 
     npartitions: int
        Number of partitions for the initial :any:`dask.bag`
     """
 
-    vanila_biometrics_pipeline.transformer = mario.wrap(
-        ["dask"], vanila_biometrics_pipeline.transformer, npartitions=npartitions
-    )
-    vanila_biometrics_pipeline.biometric_algorithm = BioAlgorithmDaskWrapper(
-        vanila_biometrics_pipeline.biometric_algorithm
-    )
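+    # Two cases are handled here: a ZTNormPipeline (its inner vanilla pipeline is
+    # dasked recursively and its ZTNorm solver is wrapped with ZTNormDaskWrapper)
+    # and a plain VanillaBiometricsPipeline (its transformer and biometric
+    # algorithm are wrapped directly).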
+    if isinstance(pipeline, ZTNormPipeline):
+        # Dask the inner vanilla biometrics pipeline first
+        pipeline.vanilla_biometrics_pipeline = dask_vanilla_biometrics(
+            pipeline.vanilla_biometrics_pipeline, npartitions
+        )
 
-    def _write_scores(scores):
-        return scores.map_partitions(vanila_biometrics_pipeline.write_scores_on_dask)
+        pipeline.ztnorm_solver = ZTNormDaskWrapper(pipeline.ztnorm_solver)
 
-    vanila_biometrics_pipeline.write_scores_on_dask = (
-        vanila_biometrics_pipeline.write_scores
-    )
-    vanila_biometrics_pipeline.write_scores = _write_scores
+    else:
 
-    return vanila_biometrics_pipeline
-
-
-class BioAlgorithmZTNormWrapper(BioAlgorithm):
-    """
-    Wraps an :any:`BioAlgorithm` with ZT score normalization
-    """
-
-    def __init__(self, biometric_algorithm, **kwargs):
-
-        self.biometric_algorithm = biometric_algorithm
-        super().__init__(**kwargs)
-
-    def enroll(self, enroll_features):
-        return self.biometric_algorithm.enroll(enroll_features)
-
-    def score(self, biometric_reference, data):
-        return self.biometric_algorithm.score(biometric_reference, data)
-
-    def score_multiple_biometric_references(self, biometric_references, data):
-        return self.biometric_algorithm.score_multiple_biometric_references(
-            biometric_references, data
+        pipeline.transformer = mario.wrap(
+            ["dask"], pipeline.transformer, npartitions=npartitions
+        )
+        pipeline.biometric_algorithm = BioAlgorithmDaskWrapper(
+            pipeline.biometric_algorithm
         )
 
-    def _norm(self, score, mu, std):
-        return (score - mu) / std
-
-
-    def compute_znorm_scores(
-        self,
-        base_norm_scores,
-        probe_scores,
-        allow_scoring_with_all_biometric_references=False,
-    ):
-        """
-        Base Z-normalization function
-        """
-
-        # Dumping all scores
-        score_floats = np.array([s.data for sset in base_norm_scores for s in sset])
-
-        # Reshaping in PROBE vs BIOMETRIC_REFERENCES
-        n_probes = len(base_norm_scores)
-        n_references = len(base_norm_scores[0].references)
-        score_floats = score_floats.reshape((n_probes, n_references))
-
-        # AXIS ON THE MODELS
-        big_mu = np.mean(score_floats, axis=0)
-        big_std = np.std(score_floats, axis=0)
-
-        # Normalizing
-        # TODO: THIS TENDS TO BE EXTREMLY SLOW
-        normed_score_samples = []
-        for probe in probe_scores:
-            sampleset = SampleSet([], parent=probe)
-            for mu, std, biometric_reference_score in zip(big_mu, big_std, probe):
-                score = self._norm(biometric_reference_score.data, mu, std)
-                new_sample = Sample(score, parent=biometric_reference_score)
-                sampleset.samples.append(new_sample)
-            normed_score_samples.append(sampleset)
-
-        return normed_score_samples
+        def _write_scores(scores):
+            return scores.map_partitions(pipeline.write_scores_on_dask)
 
+        pipeline.write_scores_on_dask = (
+            pipeline.write_scores
+        )
+        pipeline.write_scores = _write_scores
 
-    def compute_tnorm_scores(
-        self,
-        base_norm_scores,
-        probe_scores,
-        allow_scoring_with_all_biometric_references=False,
-    ):
-        """
-        Base Z-normalization function
-        """
+    return pipeline
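+
+
+# Minimal usage sketch (hypothetical variable names; see the score normalization
+# tests for a runnable example):
+#
+#   pipeline = ZTNormPipeline(VanillaBiometricsPipeline(transformer, biometric_algorithm))
+#   pipeline = dask_vanilla_biometrics(pipeline, npartitions=2)
+#   score_samples = pipeline(train, references, probes, z_probes, t_references)
+#   score_samples = score_samples.compute(scheduler="single-threaded")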
 
-        # Dumping all scores
-        score_floats = np.array([s.data for sset in base_norm_scores for s in sset])
-
-        # Reshaping in PROBE vs BIOMETRIC_REFERENCES
-        n_probes = len(base_norm_scores)
-        n_references = len(base_norm_scores[0].references)
-        score_floats = score_floats.reshape((n_probes, n_references))
-
-        # AXIS ON THE PROBES
-        big_mu = np.mean(score_floats, axis=1)
-        big_std = np.std(score_floats, axis=1)
-
-        # Normalizing
-        # TODO: THIS TENDS TO BE EXTREMLY SLOW
-        normed_score_samples = []
-        for mu, std, probe in zip(big_mu, big_std,probe_scores):
-            sampleset = SampleSet([], parent=probe)
-            for biometric_reference_score in probe:
-                score = self._norm(biometric_reference_score.data, mu, std)
-                new_sample = Sample(score, parent=biometric_reference_score)
-                sampleset.samples.append(new_sample)
-            normed_score_samples.append(sampleset)
-
-        return normed_score_samples
diff --git a/bob/bio/base/pipelines/vanilla_biometrics/zt_norm.py b/bob/bio/base/pipelines/vanilla_biometrics/zt_norm.py
new file mode 100644
index 00000000..5d0e4248
--- /dev/null
+++ b/bob/bio/base/pipelines/vanilla_biometrics/zt_norm.py
@@ -0,0 +1,384 @@
+"""
+#!/usr/bin/env python
+# vim: set fileencoding=utf-8 :
+
+Implementation of a pipeline and an algorithm that 
+computes Z, T and ZT Score Normalization of a :any:`BioAlgorithm`
+"""
+
+from bob.pipelines import DelayedSample, Sample, SampleSet
+import numpy as np
+import dask
+
+import logging
+logger = logging.getLogger(__name__)
+
+
+class ZTNormPipeline(object):
+    """
+    Apply Z, T or ZT Score normalization on top of VanillaBiometric Pipeline
+
+    Reference: "A Generative Model for Score Normalization in Speaker Recognition",
+    https://arxiv.org/pdf/1709.09868.pdf
+
+
+    Example
+    -------
+       >>> transformer = make_pipeline([])
+       >>> biometric_algorithm = Distance()
+       >>> vanilla_biometrics_pipeline = VanillaBiometricsPipeline(transformer, biometric_algorithm)
+       >>> zt_pipeline = ZTNormPipeline(vanilla_biometrics_pipeline)
+       >>> zt_pipeline(...)
+
+    Parameters
+    ----------
+
+        vanilla_biometrics_pipeline: :any:`VanillaBiometricsPipeline`
+          An instance of :any:`VanillaBiometricsPipeline` to be wrapped with score normalization
+
+        z_norm: bool
+          If True, applies ZScore normalization on top of raw scores.
+
+        t_norm: bool
+          If True, applies TScore normalization on top of raw scores.
+          If both z_norm and t_norm are True, ZT score normalization is applied.
+
+    """
+
+    def __init__(self, vanilla_biometrics_pipeline, z_norm=True, t_norm=True):
+        self.vanilla_biometrics_pipeline = vanilla_biometrics_pipeline
+
+        self.ztnorm_solver = ZTNorm()
+
+        self.z_norm = z_norm
+        self.t_norm = t_norm
+
+        if not z_norm and not t_norm:
+            raise ValueError(
+                "Both z_norm and t_norm are False. No normalization will be applied"
+            )
+
+    def __call__(
+        self,
+        background_model_samples,
+        biometric_reference_samples,
+        probe_samples,
+        zprobe_samples=None,
+        t_biometric_reference_samples=None,
+        allow_scoring_with_all_biometric_references=False,
+    ):
+
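+        # Depending on the flags, this returns Z-normed scores (z_norm only),
+        # T-normed scores (t_norm only) or ZT-normed scores (both).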
+        self.transformer = self.train_background_model(background_model_samples)
+
+        # Create biometric samples
+        biometric_references = self.create_biometric_reference(
+            biometric_reference_samples
+        )
+
+        raw_scores, probe_features = self.compute_scores(
+            probe_samples,
+            biometric_references,
+            allow_scoring_with_all_biometric_references,
+        )
+
+        # Z NORM
+        if self.z_norm:
+            if zprobe_samples is None:
+                raise ValueError("No samples for `z_norm` was provided")
+
+            z_normed_scores, z_probe_features = self.compute_znorm_scores(
+                zprobe_samples,
+                raw_scores,
+                biometric_references,
+                allow_scoring_with_all_biometric_references,
+            )
+        if self.t_norm:
+            if t_biometric_reference_samples is None:
+                raise ValueError("No samples for `t_norm` was provided")
+        else:
+            # In case z_norm=True and t_norm=False
+            return z_normed_scores
+
+        # T NORM
+        t_normed_scores, t_scores, t_biometric_references = self.compute_tnorm_scores(
+            t_biometric_reference_samples,
+            probe_features,
+            raw_scores,
+            allow_scoring_with_all_biometric_references,
+        )
+        if not self.z_norm:
+            # In case z_norm=False and t_norm=True
+            return t_normed_scores
+
+        # ZT NORM
+        zt_normed_scores = self.compute_ztnorm_scores(
+            z_probe_features,
+            t_biometric_references,
+            z_normed_scores,
+            t_scores,
+            allow_scoring_with_all_biometric_references,
+        )
+
+        return zt_normed_scores
+
+    def train_background_model(self, background_model_samples):
+        return self.vanilla_biometrics_pipeline.train_background_model(
+            background_model_samples
+        )
+
+    def create_biometric_reference(self, biometric_reference_samples):
+        return self.vanilla_biometrics_pipeline.create_biometric_reference(
+            biometric_reference_samples
+        )
+
+    def compute_scores(
+        self,
+        probe_samples,
+        biometric_references,
+        allow_scoring_with_all_biometric_references=False,
+    ):
+
+        return self.vanilla_biometrics_pipeline.compute_scores(
+            probe_samples,
+            biometric_references,
+            allow_scoring_with_all_biometric_references,
+        )
+
+    def _inject_references(self, probe_samples, biometric_references):
+        """
+        Inject the subjects of the biometric references into each probe
+        SampleSet so that scoring can be run against them.
+        """
+
+        # WARNING: this mutates the probe SampleSets in place
+        references = [s.subject for s in biometric_references]
+        for probe in probe_samples:
+            probe.references = references
+        return probe_samples
+
+    def compute_znorm_scores(
+        self,
+        zprobe_samples,
+        probe_scores,
+        biometric_references,
+        allow_scoring_with_all_biometric_references=False,
+    ):
+
+        #zprobe_samples = self._inject_references(zprobe_samples, biometric_references)
+
+        z_scores, z_probe_features = self.compute_scores(
+            zprobe_samples, biometric_references
+        )
+
+        z_normed_scores = self.ztnorm_solver.compute_znorm_scores(
+            probe_scores, z_scores, biometric_references
+        )
+
+        return z_normed_scores, z_probe_features
+
+    def compute_tnorm_scores(
+        self,
+        t_biometric_reference_samples,
+        probe_features,
+        probe_scores,
+        allow_scoring_with_all_biometric_references=False,
+    ):
+
+        t_biometric_references = self.create_biometric_reference(
+            t_biometric_reference_samples
+        )
+
+        #probe_features = self._inject_references(probe_features, t_biometric_references)
+
+        # Reusing the probe features
+        t_scores = self.vanilla_biometrics_pipeline.biometric_algorithm.score_samples(
+            probe_features,
+            t_biometric_references,
+            allow_scoring_with_all_biometric_references=allow_scoring_with_all_biometric_references,
+        )
+
+        t_normed_scores = self.ztnorm_solver.compute_tnorm_scores(
+            probe_scores, t_scores, t_biometric_references
+        )
+
+        return t_normed_scores, t_scores, t_biometric_references
+
+    def compute_ztnorm_scores(
+        self,
+        z_probe_features,
+        t_biometric_references,
+        z_normed_scores,
+        t_scores,
+        allow_scoring_with_all_biometric_references=False,
+    ):
+
+        #z_probe_features = self._inject_references(
+        #    z_probe_features, t_biometric_references
+        #)
+
+        # Reusing the zprobe_features and t_biometric_references
+        zt_scores = self.vanilla_biometrics_pipeline.biometric_algorithm.score_samples(
+            z_probe_features,
+            t_biometric_references,
+            allow_scoring_with_all_biometric_references=allow_scoring_with_all_biometric_references,
+        )
+
+        # Z-normalizing the T-scores, using the ZT-scores as the cohort
+        z_normed_t_normed = self.ztnorm_solver.compute_znorm_scores(
+            t_scores, zt_scores, t_biometric_references
+        )
+
+        # T-normalizing the Z-normed scores, using the Z-normalized T-scores as the cohort
+        zt_normed_scores = self.ztnorm_solver.compute_tnorm_scores(
+            z_normed_scores, z_normed_t_normed, t_biometric_references
+        )
+
+        return zt_normed_scores
+
+
+class ZTNorm(object):
+    """
+    Computes Z, T and ZT Score Normalization of a :any:`BioAlgorithm`
+
+    Reference: "A Generative Model for Score Normalization in Speaker Recognition",
+    https://arxiv.org/pdf/1709.09868.pdf
+
+    """
+
+    def _norm(self, score, mu, std):
+        return (score - mu) / std
+
+    def _compute_stats(self, sampleset_for_norm, biometric_references, axis=0):
+        """
+        Compute statistics for Z and T Norm.
+
+        Scores are organized as probes (rows) vs biometric references (columns), so:
+        axis=0 computes the statistics for Z-Norm (one pair per reference) and
+        axis=1 computes the statistics for T-Norm (one pair per probe).
+        """
+
+        # Dumping all scores
+        score_floats = np.array([s.data for sset in sampleset_for_norm for s in sset])
+
+        # Reshaping in PROBE vs BIOMETRIC_REFERENCES
+        n_probes = len(sampleset_for_norm)
+        n_references = len(biometric_references)
+        score_floats = score_floats.reshape((n_probes, n_references))
+
+        # Statistics along the requested axis
+        big_mu = np.mean(score_floats, axis=axis)
+        big_std = np.std(score_floats, axis=axis)
+
+        # Creating the statistics structure with the subject id as the key
+        stats = {}
+        if axis == 0:
+            for mu, std, s in zip(big_mu, big_std, sampleset_for_norm[0]):
+                stats[s.subject] = {"big_mu": mu, "big_std": std}
+        else:
+            for mu, std, sset in zip(big_mu, big_std, sampleset_for_norm):
+                stats[sset.subject] = {"big_mu": mu, "big_std": std}
+
+        return stats
+
+    def _apply_znorm(self, probe_scores, stats):
+        # Normalizing
+        # TODO: THIS TENDS TO BE EXTREMELY SLOW
+
+        normed_score_samples = []
+        for probe in probe_scores:
+            sampleset = SampleSet([], parent=probe)
+            for biometric_reference_score in probe:
+                
+                mu = stats[biometric_reference_score.subject]["big_mu"]
+                std = stats[biometric_reference_score.subject]["big_std"]
+
+                score = self._norm(biometric_reference_score.data, mu, std)
+                new_sample = Sample(score, parent=biometric_reference_score)
+                sampleset.samples.append(new_sample)
+            normed_score_samples.append(sampleset)
+
+        return normed_score_samples
+
+    def _apply_tnorm(self, probe_scores, stats):
+        # Normalizing
+        # TODO: THIS TENDS TO BE EXTREMELY SLOW
+        # MAYBE THIS COULD BE DELAYED OR RUN ON TOP OF
+
+        normed_score_samples = []
+        for probe in probe_scores:
+            sampleset = SampleSet([], parent=probe)
+
+            mu = stats[probe.subject]["big_mu"]
+            std = stats[probe.subject]["big_std"]
+
+            for biometric_reference_score in probe:
+                score = self._norm(biometric_reference_score.data, mu, std)
+                new_sample = Sample(score, parent=biometric_reference_score)
+                sampleset.samples.append(new_sample)
+            normed_score_samples.append(sampleset)
+
+        return normed_score_samples
+
+    def compute_znorm_scores(
+        self, probe_scores, sampleset_for_znorm, biometric_references
+    ):
+        """
+        Base Z-normalization function
+        """
+
+        stats = self._compute_stats(sampleset_for_znorm, biometric_references, axis=0)
+
+        return self._apply_znorm(probe_scores, stats)
+
+    def compute_tnorm_scores(
+        self,
+        probe_scores,
+        sampleset_for_tnorm,
+        t_biometric_references,
+        allow_scoring_with_all_biometric_references=False,
+    ):
+        """
+        Base T-normalization function
+        """
+
+        stats = self._compute_stats(sampleset_for_tnorm, t_biometric_references, axis=1)
+
+        return self._apply_tnorm(probe_scores, stats)
+
+
+class ZTNormDaskWrapper(object):
+    """
+    Wrap a :any:`ZTNorm` solver to work with Dask
+
+    Parameters
+    ----------
+
+        ztnorm: :any:`ZTNorm`
+            The ZTNorm solver to be wrapped
+    """
+
+    def __init__(self, ztnorm):
+
+        if not isinstance(ztnorm, ZTNorm):
+            raise ValueError("This class only wraps `ZTNorm` objects")
+
+        self.ztnorm = ztnorm
+
+    def compute_znorm_scores(
+        self, probe_scores, sampleset_for_znorm, biometric_references
+    ):
+
+        # Reducing all the Z-Scores to compute the stats
+        all_scores_for_znorm = dask.delayed(list)(sampleset_for_znorm)
+
+        stats = dask.delayed(self.ztnorm._compute_stats)(
+            all_scores_for_znorm, biometric_references, axis=0
+        )
+
+        return probe_scores.map_partitions(self.ztnorm._apply_znorm, stats)
+
+    def compute_tnorm_scores(
+        self, probe_scores, sampleset_for_tnorm, t_biometric_references
+    ):
+
+        # Reducing all the T-scores to compute the stats
+        all_scores_for_tnorm = dask.delayed(list)(sampleset_for_tnorm)
+
+        stats = dask.delayed(self.ztnorm._compute_stats)(
+            all_scores_for_tnorm, t_biometric_references, axis=1
+        )
+
+        return probe_scores.map_partitions(self.ztnorm._apply_tnorm, stats)
diff --git a/bob/bio/base/test/test_vanilla_biometrics_score_norm.py b/bob/bio/base/test/test_vanilla_biometrics_score_norm.py
index 2ae10893..193399c7 100644
--- a/bob/bio/base/test/test_vanilla_biometrics_score_norm.py
+++ b/bob/bio/base/test/test_vanilla_biometrics_score_norm.py
@@ -20,7 +20,8 @@ from bob.bio.base.test.test_vanilla_biometrics import DummyDatabase, _make_trans
 from bob.bio.base.pipelines.vanilla_biometrics import (
     Distance,
     VanillaBiometricsPipeline,
-    ZTNormVanillaBiometricsPipeline,
+    ZTNormPipeline,
+    ZTNormDaskWrapper,
     BioAlgorithmCheckpointWrapper,
     dask_vanilla_biometrics,
     BioAlgorithmLegacy,
@@ -108,127 +109,162 @@ def test_norm_mechanics():
         # and bob.bio.base is PROBES vs BIOMETRIC_REFERENCES
         return np.array([s.data for sset in scores for s in sset]).reshape(shape).T
 
-    ############
-    # Prepating stubs
-    ############
-    n_references = 2
-    n_probes = 3
-    n_t_references = 4
-    n_z_probes = 5
-
-    references = np.arange(10).reshape(
-        n_references, 5
-    )  # two references (each row different identity)
-    probes = (
-        np.arange(15).reshape(n_probes, 5) * 10
-    )  # three probes (each row different identity matching with references)
-
-    t_references = np.arange(20).reshape(
-        n_t_references, 5
-    )  # four T-REFERENCES (each row different identity)
-    z_probes = (
-        np.arange(25).reshape(n_z_probes, 5) * 10
-    )  # five Z-PROBES (each row different identity matching with t references)
-
-    (
-        raw_scores_ref,
-        z_normed_scores_ref,
-        t_normed_scores_ref,
-        zt_normed_scores_ref,
-    ) = zt_norm_stubs(references, probes, t_references, z_probes)
-
-    ############
-    # Preparing the samples
-    ############
-
-    biometric_reference_sample_sets = _create_sample_sets(references, 0)
-    reference_ids = [r.subject for r in biometric_reference_sample_sets]
-
-    probe_sample_sets = _create_sample_sets(probes, 10, reference_ids)
-
-    t_reference_sample_sets = _create_sample_sets(t_references, 20)
-    t_reference_ids = [r.subject for r in t_reference_sample_sets]
-
-    #z_probe_sample_sets = _create_sample_sets(z_probes, 30, t_reference_ids)
-    z_probe_sample_sets = _create_sample_sets(z_probes, 30, t_reference_ids)
-
-    ############
-    # TESTING REGULAR SCORING
-    #############
-    transformer = FunctionTransformer(func=_do_nothing_fn)
-    biometric_algorithm = Distance(factor=1)
-    vanilla_pipeline = VanillaBiometricsPipeline(
-        transformer, biometric_algorithm, score_writer=None
-    )
-    score_sampes = vanilla_pipeline(
-        [], biometric_reference_sample_sets, probe_sample_sets,
-        allow_scoring_with_all_biometric_references=True
-    )
-
-    raw_scores = _dump_scores_from_samples(score_sampes, shape=(n_probes, n_references))
-
-    assert np.allclose(raw_scores, raw_scores_ref)
-
-
-    ############
-    # TESTING Z-NORM
-    #############
-    z_vanilla_pipeline = ZTNormVanillaBiometricsPipeline(vanilla_pipeline,
-        z_norm=True,
-        t_norm=False,
-    )
-
-    z_normed_score_samples = z_vanilla_pipeline(
-        [],
-        biometric_reference_sample_sets,
-        copy.deepcopy(probe_sample_sets),
-        z_probe_sample_sets,
-        t_reference_sample_sets,
-    )
-
-    z_normed_scores = _dump_scores_from_samples(z_normed_score_samples, shape=(n_probes, n_references))
-    assert np.allclose(z_normed_scores, z_normed_scores_ref)
-
-    ############
-    # TESTING T-NORM
-    #############
-    t_vanilla_pipeline = ZTNormVanillaBiometricsPipeline(vanilla_pipeline,
-        z_norm=False,
-        t_norm=True,
-    )
-
-    t_normed_score_samples = t_vanilla_pipeline(
-        [],
-        biometric_reference_sample_sets,
-        copy.deepcopy(probe_sample_sets),
-        z_probe_sample_sets,
-        t_reference_sample_sets,
-    )
-
-    t_normed_scores = _dump_scores_from_samples(t_normed_score_samples, shape=(n_probes, n_references))
-    assert np.allclose(t_normed_scores, t_normed_scores_ref)
-
-
-    ############
-    # TESTING ZT-NORM
-    #############
-    zt_vanilla_pipeline = ZTNormVanillaBiometricsPipeline(vanilla_pipeline,
-        z_norm=True,
-        t_norm=True,
-    )
-
-    zt_normed_score_samples = zt_vanilla_pipeline(
-        [],
-        biometric_reference_sample_sets,
-        copy.deepcopy(probe_sample_sets),
-        z_probe_sample_sets,
-        t_reference_sample_sets,
-    )
-
-    zt_normed_scores = _dump_scores_from_samples(zt_normed_score_samples, shape=(n_probes, n_references))
-    assert np.allclose(zt_normed_scores, zt_normed_scores_ref)
-
-
+    def run(with_dask):
+        ############
+        # Preparing stubs
+        ############
+        n_references = 10
+        n_probes = 34
+        n_t_references = 44
+        n_z_probes = 15
+        dim = 5
+
+        references = np.arange(n_references * dim).reshape(
+            n_references, dim
+        )  # n_references references (each row is a different identity)
+        probes = (
+            np.arange(n_probes * dim).reshape(n_probes, dim) * 10
+        )  # n_probes probes (each row is a different identity matching the references)
+
+        t_references = np.arange(n_t_references * dim).reshape(
+            n_t_references, dim
+        )  # n_t_references T-references (each row is a different identity)
+        z_probes = (
+            np.arange(n_z_probes * dim).reshape(n_z_probes, dim) * 10
+        )  # n_z_probes Z-probes (each row is a different identity matching the T-references)
+
+        (
+            raw_scores_ref,
+            z_normed_scores_ref,
+            t_normed_scores_ref,
+            zt_normed_scores_ref,
+        ) = zt_norm_stubs(references, probes, t_references, z_probes)
+
+        ############
+        # Preparing the samples
+        ############
+
+        # Creating enrollment samples
+        biometric_reference_sample_sets = _create_sample_sets(references, 0)
+        t_reference_sample_sets = _create_sample_sets(t_references, 20)
+
+        # Fetching ids
+        reference_ids = [r.subject for r in biometric_reference_sample_sets]
+        t_reference_ids = [r.subject for r in t_reference_sample_sets]
+        ids = reference_ids + t_reference_ids
+
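+        # The probes and Z-probes carry both the reference and the T-reference ids,
+        # so the same probe features can be scored against either set of models.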
+        probe_sample_sets = _create_sample_sets(probes, 10, ids)
+        z_probe_sample_sets = _create_sample_sets(z_probes, 30, ids)
+
+
+        ############
+        # TESTING REGULAR SCORING
+        #############
+
+        transformer = make_pipeline(FunctionTransformer(func=_do_nothing_fn))
+        biometric_algorithm = Distance(factor=1)
+
+        vanilla_pipeline = VanillaBiometricsPipeline(
+            transformer, biometric_algorithm, score_writer=None
+        )        
+        if with_dask:
+            vanilla_pipeline = dask_vanilla_biometrics(vanilla_pipeline)
+
+
+        score_samples = vanilla_pipeline(
+            [], biometric_reference_sample_sets, probe_sample_sets,
+            allow_scoring_with_all_biometric_references=True
+        )
+
+        if with_dask:
+            score_samples = score_samples.compute(scheduler="single-threaded")
+
+        raw_scores = _dump_scores_from_samples(score_samples, shape=(n_probes, n_references))
+        assert np.allclose(raw_scores, raw_scores_ref)
+
+        ############
+        # TESTING Z-NORM
+        #############
+        
+        z_vanilla_pipeline = ZTNormPipeline(vanilla_pipeline,
+            z_norm=True,
+            t_norm=False,
+        )
+        if with_dask:
+            z_vanilla_pipeline.ztnorm_solver = ZTNormDaskWrapper(z_vanilla_pipeline.ztnorm_solver)
+
+        z_normed_score_samples = z_vanilla_pipeline(
+            [],
+            biometric_reference_sample_sets,
+            copy.deepcopy(probe_sample_sets),
+            z_probe_sample_sets,
+            t_reference_sample_sets,
+        )        
+
+        if with_dask:
+            z_normed_score_samples = z_normed_score_samples.compute(scheduler="single-threaded")
+
+        z_normed_scores = _dump_scores_from_samples(z_normed_score_samples, shape=(n_probes, n_references))
+        assert np.allclose(z_normed_scores, z_normed_scores_ref)
+        
+
+        ############
+        # TESTING T-NORM
+        #############
+        
+        t_vanilla_pipeline = ZTNormPipeline(vanilla_pipeline,
+            z_norm=False,
+            t_norm=True,
+        )
+        if with_dask:
+            t_vanilla_pipeline.ztnorm_solver = ZTNormDaskWrapper(t_vanilla_pipeline.ztnorm_solver)
+
+        t_normed_score_samples = t_vanilla_pipeline(
+            [],
+            biometric_reference_sample_sets,
+            copy.deepcopy(probe_sample_sets),
+            z_probe_sample_sets,
+            t_reference_sample_sets,
+        )
+        
+        if with_dask:
+            t_normed_score_samples = t_normed_score_samples.compute(scheduler='single-threaded')
+
+        t_normed_scores = _dump_scores_from_samples(t_normed_score_samples, shape=(n_probes, n_references))
+        assert np.allclose(t_normed_scores, t_normed_scores_ref)
+        
+        
+        ############
+        # TESTING ZT-NORM
+        #############
+        zt_vanilla_pipeline = ZTNormPipeline(vanilla_pipeline,
+            z_norm=True,
+            t_norm=True,
+        )
+
+        if with_dask:
+            zt_vanilla_pipeline.ztnorm_solver = ZTNormDaskWrapper(zt_vanilla_pipeline.ztnorm_solver)
+
+        zt_normed_score_samples = zt_vanilla_pipeline(
+            [],
+            biometric_reference_sample_sets,
+            copy.deepcopy(probe_sample_sets),
+            z_probe_sample_sets,
+            t_reference_sample_sets,
+        )
+
+        if with_dask:
+            zt_normed_score_samples = zt_normed_score_samples.compute(scheduler="single-threaded")
+
+        zt_normed_scores = _dump_scores_from_samples(zt_normed_score_samples, shape=(n_probes, n_references))
+
+        assert np.allclose(zt_normed_scores, zt_normed_scores_ref)
+        
+    # No dask
+    run(False)
+
+    # With dask
+    run(True)
 
 
 def test_znorm_on_memory():
@@ -243,7 +279,7 @@ def test_znorm_on_memory():
 
             biometric_algorithm = Distance()
 
-            vanilla_biometrics_pipeline = ZTNormVanillaBiometricsPipeline(
+            vanilla_biometrics_pipeline = ZTNormPipeline(
                 VanillaBiometricsPipeline(transformer, biometric_algorithm)
             )
 
@@ -267,8 +303,8 @@ def test_znorm_on_memory():
             assert len(scores) == 10
 
         run_pipeline(False)
-        # run_pipeline(False)  # Testing checkpoint
+        run_pipeline(False)  # Testing checkpoint
         # shutil.rmtree(dir_name)  # Deleting the cache so it runs again from scratch
         # os.makedirs(dir_name, exist_ok=True)
-        # run_pipeline(True)
-        # run_pipeline(True)  # Testing checkpoint
+        #run_pipeline(True)
+        #run_pipeline(True)  # Testing checkpoint
-- 
GitLab