From 735ad313b49bae8c4986c3c9a9068e50e7ca6e20 Mon Sep 17 00:00:00 2001
From: Tiago Freitas Pereira <tiagofrepereira@gmail.com>
Date: Thu, 30 Apr 2020 20:02:34 +0200
Subject: [PATCH] Reorganizing the structure of vanilla_biometrics
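
Reorganize the vanilla_biometrics package: `pipeline.py` becomes
`pipelines.py` (with `VanillaBiometrics` renamed to
`VanillaBiometricsPipeline`), `implemented.py` becomes
`biometric_algorithms.py`, and the checkpoint/dask mixins in
`mixins.py` are replaced by wrapper classes such as
`BioAlgorithmCheckpointWrapper`. Score writing moves from the ad-hoc
`make_four_colums_score` helper into a `ScoreWriter` hierarchy (e.g.
`FourColumnsScoreWriter`), and `wrap_transform_bob` now defaults
`fit_extra_arguments` to `(("y", "subject"),)`.

A minimal usage sketch of the reorganized API, adapted from the updated
tests (`transformer` and `database` are placeholders for any
scikit-learn transformer and database object; the paths are
illustrative):

    from bob.bio.base.pipelines.vanilla_biometrics import (
        Distance,
        VanillaBiometricsPipeline,
        BioAlgorithmCheckpointWrapper,
        FourColumnsScoreWriter,
    )

    # Checkpoint enrolled references and scores under ./checkpoints
    biometric_algorithm = BioAlgorithmCheckpointWrapper(
        Distance(), base_dir="./checkpoints"
    )
    pipeline = VanillaBiometricsPipeline(transformer, biometric_algorithm)

    scores = pipeline(
        database.background_model_samples(),
        database.references(),
        database.probes(),
        allow_scoring_with_all_biometric_references=True,
    )

    # Write all scores to a single four-column file
    FourColumnsScoreWriter().concatenate_write_scores(scores, "scores.txt")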

---
 .../pipelines/vanilla_biometrics/__init__.py  |   6 +-
 .../vanilla_biometrics/abstract_classes.py    |  51 ++++---
 ...implemented.py => biometric_algorithms.py} |   8 +-
 .../pipelines/vanilla_biometrics/legacy.py    |   1 -
 .../pipelines/vanilla_biometrics/mixins.py    | 122 -----------------
 .../{pipeline.py => pipelines.py}             |   6 +-
 bob/bio/base/test/test_transformers.py        |  22 +--
 bob/bio/base/test/test_vanilla_biometrics.py  | 126 +++++++++++++-----
 bob/bio/base/transformers/algorithm.py        |  12 +-
 bob/bio/base/transformers/extractor.py        |   3 +-
 bob/bio/base/wrappers.py                      |   6 +-
 11 files changed, 165 insertions(+), 198 deletions(-)
 rename bob/bio/base/pipelines/vanilla_biometrics/{implemented.py => biometric_algorithms.py} (94%)
 delete mode 100644 bob/bio/base/pipelines/vanilla_biometrics/mixins.py
 rename bob/bio/base/pipelines/vanilla_biometrics/{pipeline.py => pipelines.py} (99%)

diff --git a/bob/bio/base/pipelines/vanilla_biometrics/__init__.py b/bob/bio/base/pipelines/vanilla_biometrics/__init__.py
index d6783573..ba5ef249 100644
--- a/bob/bio/base/pipelines/vanilla_biometrics/__init__.py
+++ b/bob/bio/base/pipelines/vanilla_biometrics/__init__.py
@@ -1,6 +1,10 @@
 # see https://docs.python.org/3/library/pkgutil.html
 from pkgutil import extend_path
 
-from .pipeline import VanillaBiometrics, dask_vanilla_biometrics
+from .pipelines import VanillaBiometricsPipeline
+from .biometric_algorithms import Distance
+from .score_writers import FourColumnsScoreWriter
+from .wrappers import BioAlgorithmCheckpointWrapper
+
 
 __path__ = extend_path(__path__, __name__)
diff --git a/bob/bio/base/pipelines/vanilla_biometrics/abstract_classes.py b/bob/bio/base/pipelines/vanilla_biometrics/abstract_classes.py
index e27c27aa..022725bc 100644
--- a/bob/bio/base/pipelines/vanilla_biometrics/abstract_classes.py
+++ b/bob/bio/base/pipelines/vanilla_biometrics/abstract_classes.py
@@ -1,3 +1,7 @@
+#!/usr/bin/env python
+# vim: set fileencoding=utf-8 :
+
+
 from abc import ABCMeta, abstractmethod
 from bob.pipelines.sample import Sample, SampleSet, DelayedSample
 import functools
@@ -11,7 +15,8 @@ class BioAlgorithm(metaclass=ABCMeta):
 
     Parameters
     ----------
-
+      allow_score_multiple_references: bool
+         If ``True``, the scoring function can be executed by :any:`BioAlgorithm.score_multiple_biometric_references`.
 
     """
 
@@ -111,10 +116,6 @@ class BioAlgorithm(metaclass=ABCMeta):
         # To be honest, this should be the default behaviour
         retval = []
 
-        def _write_sample(ref, probe, score):
-            data = make_four_colums_score(ref.subject, probe.subject, probe.path, score)
-            return Sample(data, parent=ref)
-
         for subprobe_id, (s, parent) in enumerate(zip(data, sampleset.samples)):
             # Creating one sample per comparison
             subprobe_scores = []
@@ -129,20 +130,20 @@ class BioAlgorithm(metaclass=ABCMeta):
                     self.stacked_biometric_references, s
                 )
                 
-                # Wrapping the scores in samples
-                for ref, score in zip(biometric_references, scores):
-                    subprobe_scores.append(_write_sample(ref, sampleset, score))
+                # Wrapping the scores in samples
+                for ref, score in zip(biometric_references, scores):
+                    subprobe_scores.append(Sample(score, parent=ref))
             else:
 
                 for ref in [
                     r for r in biometric_references if r.key in sampleset.references
                 ]:
                     score = self.score(ref.data, s)
-                    subprobe_scores.append(_write_sample(ref, sampleset, score))
+                    subprobe_scores.append(Sample(score, parent=ref))
 
             # Creating one sampleset per probe
-            subprobe = SampleSet(subprobe_scores, parent=sampleset)
-            subprobe.subprobe_id = subprobe_id
+            subprobe = SampleSet(subprobe_scores, parent=parent)
+                subprobe.subject = sampleset.subject
             retval.append(subprobe)
 
         return retval
@@ -245,13 +246,27 @@ class Database(metaclass=ABCMeta):
         pass
 
 
-def make_four_colums_score(
-    biometric_reference_subject, probe_subject, probe_path, score,
-):
-    data = "{0} {1} {2} {3}\n".format(
-        biometric_reference_subject, probe_subject, probe_path, score,
-    )
-    return data
+class ScoreWriter(metaclass=ABCMeta):
+    """
+    Defines the base methods to write, read, and concatenate the scores
+    produced by a :any:`BioAlgorithm`.
+    """
+
+    def __init__(self, extension=".txt"):
+        self.extension = extension
+
+    @abstractmethod
+    def write(self, sampleset, path):
+        pass
+
+
+    @abstractmethod
+    def read(self, path):
+        pass
+
+    @abstractmethod
+    def concatenate_write_scores(self, sampleset, path):
+        pass
 
 
 def create_score_delayed_sample(path, probe):
diff --git a/bob/bio/base/pipelines/vanilla_biometrics/implemented.py b/bob/bio/base/pipelines/vanilla_biometrics/biometric_algorithms.py
similarity index 94%
rename from bob/bio/base/pipelines/vanilla_biometrics/implemented.py
rename to bob/bio/base/pipelines/vanilla_biometrics/biometric_algorithms.py
index f77e66d3..233c9863 100644
--- a/bob/bio/base/pipelines/vanilla_biometrics/implemented.py
+++ b/bob/bio/base/pipelines/vanilla_biometrics/biometric_algorithms.py
@@ -2,8 +2,10 @@ import scipy.spatial.distance
 from sklearn.utils.validation import check_array
 import numpy
 from .abstract_classes import BioAlgorithm
-from .mixins import BioAlgCheckpointMixin
 from scipy.spatial.distance import cdist
+import os
+from bob.pipelines import DelayedSample, Sample, SampleSet
+import functools
 
 
 class Distance(BioAlgorithm):
@@ -69,7 +71,3 @@ class Distance(BioAlgorithm):
         )
 
         return list(scores.flatten())
-
-
-class CheckpointDistance(BioAlgCheckpointMixin, Distance):
-    pass
diff --git a/bob/bio/base/pipelines/vanilla_biometrics/legacy.py b/bob/bio/base/pipelines/vanilla_biometrics/legacy.py
index 9eec5a2a..6b44b591 100644
--- a/bob/bio/base/pipelines/vanilla_biometrics/legacy.py
+++ b/bob/bio/base/pipelines/vanilla_biometrics/legacy.py
@@ -17,7 +17,6 @@ from .abstract_classes import (
 from bob.io.base import HDF5File
 from bob.pipelines.mixins import SampleMixin, CheckpointMixin
 from bob.pipelines.sample import DelayedSample, SampleSet, Sample
-from sklearn.base import TransformerMixin, BaseEstimator
 import logging
 import copy
 
diff --git a/bob/bio/base/pipelines/vanilla_biometrics/mixins.py b/bob/bio/base/pipelines/vanilla_biometrics/mixins.py
deleted file mode 100644
index d21bc9c8..00000000
--- a/bob/bio/base/pipelines/vanilla_biometrics/mixins.py
+++ /dev/null
@@ -1,122 +0,0 @@
-from bob.pipelines.mixins import CheckpointMixin
-from bob.pipelines.sample import DelayedSample
-import bob.io.base
-import os
-import functools
-import dask
-from .abstract_classes import create_score_delayed_sample
-
-
-class BioAlgCheckpointMixin(CheckpointMixin):
-    """Mixing used to checkpoint Enrolled and Scoring samples.
-
-    Examples
-    --------
-
-    >>> from bob.bio.base.pipelines.vanilla_biometrics.biometric_algorithm import BioAlgCheckpointMixin, Distance
-    >>> class DistanceCheckpoint(BioAlgCheckpointMixin, Distance) pass:
-    >>> biometric_algorithm = DistanceCheckpoint(features_dir="./")
-    >>> biometric_algorithm.enroll(sample)
-
-    It's possible to use it as with the :py:func:`bob.pipelines.mixins.mix_me_up`
-
-    >>> from bob.pipelines.mixins import mix_me_up
-    >>> biometric_algorithm = mix_me_up([BioAlgCheckpointMixin], Distance)(features_dir="./")
-    >>> biometric_algorithm.enroll(sample)
-
-    """
-
-    def __init__(self, features_dir="", **kwargs):
-        super().__init__(features_dir=features_dir, **kwargs)
-        self.biometric_reference_dir = os.path.join(
-            features_dir, "biometric_references"
-        )
-        self.score_dir = os.path.join(features_dir, "scores")
-
-    def save(self, sample, path):
-        return bob.io.base.save(sample.data, path, create_directories=True)
-
-    def _enroll_sample_set(self, sampleset):
-        """
-        Enroll a sample set with checkpointing
-        """
-
-        # Amending `models` directory
-        path = os.path.join(
-            self.biometric_reference_dir, str(sampleset.key) + self.extension
-        )
-        if path is None or not os.path.isfile(path):
-
-            # Enrolling the sample
-            enrolled_sample = super()._enroll_sample_set(sampleset)
-
-            # saving the new sample
-            self.save(enrolled_sample, path)
-
-            # Dealaying it.
-            # This seems inefficient, but it's crucial for large datasets
-            delayed_enrolled_sample = DelayedSample(
-                functools.partial(bob.io.base.load, path), enrolled_sample
-            )
-
-        else:
-            # If sample already there, just load
-            delayed_enrolled_sample = self.load(sampleset, path)
-
-        return delayed_enrolled_sample
-
-    def _score_sample_set(
-        self,
-        sampleset,
-        biometric_references,
-        allow_scoring_with_all_biometric_references=False
-    ):
-        """Given a sampleset for probing, compute the scores and retures a sample set with the scores
-        """
-
-        # Computing score
-        scored_sample_set = super()._score_sample_set(
-            sampleset,
-            biometric_references,
-            allow_scoring_with_all_biometric_references=allow_scoring_with_all_biometric_references,
-        )
-        for s in scored_sample_set:
-            # Checkpointing score
-            path = os.path.join(self.score_dir, str(s.path) + ".txt")
-            os.makedirs(os.path.dirname(path), exist_ok=True)
-
-            delayed_scored_sample = create_score_delayed_sample(path, s)
-            s.samples = [delayed_scored_sample]
-
-        return scored_sample_set
-
-
-class BioAlgDaskMixin:
-    def enroll_samples(self, biometric_reference_features):
-        biometric_references = biometric_reference_features.map_partitions(
-            super().enroll_samples
-        )
-        return biometric_references
-
-    def score_samples(
-        self,
-        probe_features,
-        biometric_references,
-        allow_scoring_with_all_biometric_references=False,
-    ):
-
-        # TODO: Here, we are sending all computed biometric references to all
-        # probes.  It would be more efficient if only the models related to each
-        # probe are sent to the probing split.  An option would be to use caching
-        # and allow the ``score`` function above to load the required data from
-        # the disk, directly.  A second option would be to generate named delays
-        # for each model and then associate them here.
-
-        all_references = dask.delayed(list)(biometric_references)
-
-        scores = probe_features.map_partitions(
-            super().score_samples,
-            all_references,
-            allow_scoring_with_all_biometric_references=allow_scoring_with_all_biometric_references,
-        )
-        return scores
diff --git a/bob/bio/base/pipelines/vanilla_biometrics/pipeline.py b/bob/bio/base/pipelines/vanilla_biometrics/pipelines.py
similarity index 99%
rename from bob/bio/base/pipelines/vanilla_biometrics/pipeline.py
rename to bob/bio/base/pipelines/vanilla_biometrics/pipelines.py
index c5d0cbf4..48ba26d8 100644
--- a/bob/bio/base/pipelines/vanilla_biometrics/pipeline.py
+++ b/bob/bio/base/pipelines/vanilla_biometrics/pipelines.py
@@ -14,7 +14,7 @@ import numpy
 logger = logging.getLogger(__name__)
 
 
-class VanillaBiometrics(object):
+class VanillaBiometricsPipeline(object):
     """
     Vanilla Biometrics Pipeline
 
@@ -120,7 +120,7 @@ class VanillaBiometrics(object):
         biometric_reference_features = self.transformer.transform(
             biometric_reference_samples
         )
-
+        
         biometric_references = self.biometric_algorithm.enroll_samples(
             biometric_reference_features
         )
@@ -137,7 +137,7 @@ class VanillaBiometrics(object):
 
         # probes is a list of SampleSets
         probe_features = self.transformer.transform(probe_samples)
-
+        
         scores = self.biometric_algorithm.score_samples(
             probe_features,
             biometric_references,
diff --git a/bob/bio/base/test/test_transformers.py b/bob/bio/base/test/test_transformers.py
index c78d5fab..3360d184 100644
--- a/bob/bio/base/test/test_transformers.py
+++ b/bob/bio/base/test/test_transformers.py
@@ -23,17 +23,17 @@ from bob.bio.base.wrappers import (
 from sklearn.pipeline import make_pipeline
 
 
-class _FakePreprocesor(Preprocessor):
+class FakePreprocesor(Preprocessor):
     def __call__(self, data, annotations=None):
         return data + annotations
 
 
-class _FakeExtractor(Extractor):
+class FakeExtractor(Extractor):
     def __call__(self, data, metadata=None):
         return data.flatten()
 
 
-class _FakeExtractorFittable(Extractor):
+class FakeExtractorFittable(Extractor):
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
         self.requires_training = True
@@ -47,7 +47,7 @@ class _FakeExtractorFittable(Extractor):
         bob.io.base.save(self.model, extractor_file)
 
 
-class _FakeAlgorithm(Algorithm):
+class FakeAlgorithm(Algorithm):
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
         self.requires_training = True
@@ -103,7 +103,7 @@ def assert_checkpoints(transformed_sample, dir_name):
 
 def test_preprocessor():
 
-    preprocessor = _FakePreprocesor()
+    preprocessor = FakePreprocesor()
     preprocessor_transformer = PreprocessorTransformer(preprocessor)
 
     # Testing sample
@@ -136,7 +136,7 @@ def test_preprocessor():
 
 def test_extractor():
 
-    extractor = _FakeExtractor()
+    extractor = FakeExtractor()
     extractor_transformer = ExtractorTransformer(extractor)
 
     # Testing sample
@@ -168,7 +168,7 @@ def test_extractor_fittable():
     with tempfile.TemporaryDirectory() as dir_name:
 
         extractor_file = os.path.join(dir_name, "Extractor.hdf5")
-        extractor = _FakeExtractorFittable()
+        extractor = FakeExtractorFittable()
         extractor_transformer = ExtractorTransformer(
             extractor, model_path=extractor_file
         )
@@ -207,7 +207,7 @@ def test_algorithm():
         projector_file = os.path.join(dir_name, "Projector.hdf5")
         projector_pkl = os.path.join(dir_name, "Projector.pkl")  # Testing pickling
 
-        algorithm = _FakeAlgorithm()
+        algorithm = FakeAlgorithm()
         algorithm_transformer = AlgorithmTransformer(
             algorithm, projector_file=projector_file
         )
@@ -258,13 +258,13 @@ def test_wrap_bob_pipeline():
 
             pipeline = make_pipeline(
                 wrap_transform_bob(
-                    _FakePreprocesor(),
+                    FakePreprocesor(),
                     dir_name,
                     transform_extra_arguments=(("annotations", "annotations"),),
                 ),
-                wrap_transform_bob(_FakeExtractor(), dir_name,),
+                wrap_transform_bob(FakeExtractor(), dir_name,),
                 wrap_transform_bob(
-                    _FakeAlgorithm(), dir_name, fit_extra_arguments=(("y", "subject"),)
+                    FakeAlgorithm(), dir_name
                 ),
             )
             oracle = [7.0, 7.0, 7.0, 7.0]
diff --git a/bob/bio/base/test/test_vanilla_biometrics.py b/bob/bio/base/test/test_vanilla_biometrics.py
index 75d93548..9acd489d 100644
--- a/bob/bio/base/test/test_vanilla_biometrics.py
+++ b/bob/bio/base/test/test_vanilla_biometrics.py
@@ -2,33 +2,51 @@
 # vim: set fileencoding=utf-8 :
 # @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
 
-from bob.pipelines.sample import Sample, SampleSet, DelayedSample
+from bob.pipelines import Sample, SampleSet, DelayedSample
 import os
 import numpy
 import tempfile
-from sklearn.utils.validation import check_is_fitted
+from sklearn.pipeline import make_pipeline
+from bob.bio.base.wrappers import wrap_transform_bob
+from bob.bio.base.test.test_transformers import FakePreprocesor, FakeExtractor
+from bob.bio.base.pipelines.vanilla_biometrics import (
+    Distance,
+    VanillaBiometricsPipeline,
+)
+from bob.bio.base.pipelines.vanilla_biometrics import (
+    BioAlgorithmCheckpointWrapper,
+    FourColumnsScoreWriter,
+)
+import uuid
 
 
 class DummyDatabase:
-
-    def __init__(self, delayed=False, n_references=10, n_probes=10, dim=10, one_d = True):
+    def __init__(self, delayed=False, n_references=10, n_probes=10, dim=10, one_d=True):
         self.delayed = delayed
         self.dim = dim
         self.n_references = n_references
         self.n_probes = n_probes
         self.one_d = one_d
 
-
     def _create_random_1dsamples(self, n_samples, offset, dim):
-        return [ Sample(numpy.random.rand(dim), key=i) for i in range(offset,offset+n_samples) ]
+        return [
+            Sample(numpy.random.rand(dim), key=str(uuid.uuid4()), annotations=1)
+            for i in range(offset, offset + n_samples)
+        ]
 
     def _create_random_2dsamples(self, n_samples, offset, dim):
-        return [ Sample(numpy.random.rand(dim, dim), key=i) for i in range(offset,offset+n_samples) ]
+        return [
+            Sample(numpy.random.rand(dim, dim), key=str(uuid.uuid4()), annotations=1)
+            for i in range(offset, offset + n_samples)
+        ]
 
     def _create_random_sample_set(self, n_sample_set=10, n_samples=2):
 
         # Just generate random samples
-        sample_set = [SampleSet(samples=[], key=i) for i in range(n_sample_set)]
+        sample_set = [
+            SampleSet(samples=[], key=str(i), subject=str(i))
+            for i in range(n_sample_set)
+        ]
 
         offset = 0
         for s in sample_set:
@@ -42,41 +60,89 @@ class DummyDatabase:
 
         return sample_set
 
-
     def background_model_samples(self):
         return self._create_random_sample_set()
 
-
     def references(self):
         return self._create_random_sample_set(self.n_references, self.dim)
 
-
     def probes(self):
-        probes = self._create_random_sample_set(self.n_probes, self.dim)
+        probes = []
+
+        probes = self._create_random_sample_set(n_sample_set=10, n_samples=1)
         for p in probes:
             p.references = list(range(self.n_references))
+
         return probes
 
+    @property
+    def allow_scoring_with_all_biometric_references(self):
+        return True
+
+
+def _make_transformer(dir_name):
+    return make_pipeline(
+        wrap_transform_bob(
+            FakePreprocesor(),
+            dir_name,
+            transform_extra_arguments=(("annotations", "annotations"),),
+        ),
+        wrap_transform_bob(FakeExtractor(), dir_name,),
+    )
+
+
+def test_on_memory():
+
+    with tempfile.TemporaryDirectory() as dir_name:
+        database = DummyDatabase()
+
+        transformer = _make_transformer(dir_name)
+
+        biometric_algorithm = Distance()
+
+        biometric_pipeline = VanillaBiometricsPipeline(transformer, biometric_algorithm)
+
+        scores = biometric_pipeline(
+            database.background_model_samples(),
+            database.references(),
+            database.probes(),
+            allow_scoring_with_all_biometric_references=database.allow_scoring_with_all_biometric_references,
+        )
+
+        assert len(scores) == 10
+        for probe_ssets in scores:
+            for probe in probe_ssets:
+                assert len(probe) == 10
+
+def test_checkpoint():
+
+    with tempfile.TemporaryDirectory() as dir_name:
+
+        def run_pipeline(with_dask):
+            database = DummyDatabase()
+
+            transformer = _make_transformer(dir_name)
+
+            biometric_algorithm = BioAlgorithmCheckpointWrapper(
+                Distance(), base_dir=dir_name
+            )
 
-from bob.bio.base.pipelines.vanilla_biometrics.biometric_algorithm import Distance
-import itertools
-def test_distance_comparator():
+            biometric_pipeline = VanillaBiometricsPipeline(transformer, biometric_algorithm)
 
-    n_references = 10
-    dim = 10
-    n_probes = 10
-    database = DummyDatabase(delayed=False, n_references=n_references, n_probes=n_probes, dim=10, one_d = True)
-    references = database.references()
-    probes = database.probes()
+            scores = biometric_pipeline(
+                database.background_model_samples(),
+                database.references(),
+                database.probes(),
+                allow_scoring_with_all_biometric_references=database.allow_scoring_with_all_biometric_references,
+            )
 
-    comparator = Distance()
-    references = comparator.enroll_samples(references)
-    assert len(references)== n_references
-    assert references[0].data.shape == (dim,)
+            filename = os.path.join(dir_name, "concatenated_scores.txt")
+            FourColumnsScoreWriter().concatenate_write_scores(
+                scores, filename
+            )
+
+            assert len(open(filename).readlines()) == 100
 
-    probes = database.probes()
-    scores = comparator.score_samples(probes, references)
-    scores = list(itertools.chain(*scores))
+        run_pipeline(False)
+        run_pipeline(False)  # Checking if the checkpoints work
 
-    assert len(scores) == n_probes*n_references
-    assert len(scores[0].samples)==n_references
diff --git a/bob/bio/base/transformers/algorithm.py b/bob/bio/base/transformers/algorithm.py
index f49b906c..6e74a306 100644
--- a/bob/bio/base/transformers/algorithm.py
+++ b/bob/bio/base/transformers/algorithm.py
@@ -7,6 +7,7 @@ from bob.pipelines.utils import is_picklable
 from . import split_X_by_y
 import os
 
+
 class AlgorithmTransformer(TransformerMixin, BaseEstimator):
     """Class that wraps :any:`bob.bio.base.algorithm.Algoritm`
 
@@ -56,8 +57,8 @@ class AlgorithmTransformer(TransformerMixin, BaseEstimator):
 
     def fit(self, X, y=None):
         if not self.callable.requires_training:
-            return self        
-        training_data = X        
+            return self
+        training_data = X
         if self.callable.split_training_features_by_client:
             training_data = split_X_by_y(X, y)
 
@@ -65,7 +66,7 @@ class AlgorithmTransformer(TransformerMixin, BaseEstimator):
         self.callable.train_projector(training_data, self.projector_file)
         return self
 
-    def transform(self, X, metadata=None):        
+    def transform(self, X, metadata=None):
         if metadata is None:
             return [self.callable.project(data) for data in X]
         else:
@@ -75,4 +76,7 @@ class AlgorithmTransformer(TransformerMixin, BaseEstimator):
             ]
 
     def _more_tags(self):
-        return {"stateless": not self.callable.requires_training, "requires_fit": self.callable.requires_training}
+        return {
+            "stateless": not self.callable.requires_training,
+            "requires_fit": self.callable.requires_training,
+        }
diff --git a/bob/bio/base/transformers/extractor.py b/bob/bio/base/transformers/extractor.py
index 40778e77..ab28508f 100644
--- a/bob/bio/base/transformers/extractor.py
+++ b/bob/bio/base/transformers/extractor.py
@@ -5,6 +5,7 @@ from sklearn.base import TransformerMixin, BaseEstimator
 from bob.bio.base.extractor import Extractor
 from . import split_X_by_y
 
+
 class ExtractorTransformer(TransformerMixin, BaseEstimator):
     """
     Scikit learn transformer for :any:`bob.bio.base.extractor.Extractor`.
@@ -29,7 +30,7 @@ class ExtractorTransformer(TransformerMixin, BaseEstimator):
                 "`callable` should be an instance of `bob.bio.base.extractor.Extractor`"
             )
 
-        if callable.requires_training and (model_path is None or model_path==""):
+        if callable.requires_training and (model_path is None or model_path == ""):
             raise ValueError(
                 f"`model_path` needs to be set if extractor {callable} requires training"
             )
diff --git a/bob/bio/base/wrappers.py b/bob/bio/base/wrappers.py
index 2142c231..37499e3c 100644
--- a/bob/bio/base/wrappers.py
+++ b/bob/bio/base/wrappers.py
@@ -13,9 +13,11 @@ import bob.pipelines as mario
 import os
 
 
-
 def wrap_transform_bob(
-    bob_object, dir_name, fit_extra_arguments=None, transform_extra_arguments=None
+    bob_object,
+    dir_name,
+    fit_extra_arguments=(("y", "subject"),),
+    transform_extra_arguments=None,
 ):
     """
     Wraps either :any:`bob.bio.base.preprocessor.Preprocessor`, :any:`bob.bio.base.extractor.Extractor`
-- 
GitLab