diff --git a/bob/bio/base/pipelines/vanilla_biometrics/legacy.py b/bob/bio/base/pipelines/vanilla_biometrics/legacy.py
index d358dad3e207cddebfac90a74cf21f454a0ecb6c..9eec5a2a75323bf7153cae8d120fd0084876a974 100644
--- a/bob/bio/base/pipelines/vanilla_biometrics/legacy.py
+++ b/bob/bio/base/pipelines/vanilla_biometrics/legacy.py
@@ -157,240 +157,6 @@ class DatabaseConnector(Database):
 
         return list(probes.values())
 
-class _NonPickableWrapper:
-    def __init__(self, callable, **kwargs):
-        super().__init__(**kwargs)
-        self.callable = callable
-        self._instance = None
-
-    @property
-    def instance(self):
-        # Input can be a functools.partial or an object
-        if isinstance(self.callable, functools.partial) and self._instance is None:
-            self._instance = self.callable()
-        else:
-            self._instance = self.callable
-
-        return self._instance
-
-    def __setstate__(self, d):
-        # Handling unpicklable objects
-        self.__dict__ = d
-
-    def __getstate__(self):
-        # Handling unpicklable objects
-        if isinstance(self.callable, functools.partial):
-            self._instance = None
-        return self.__dict__
-
-
-class _Preprocessor(_NonPickableWrapper, TransformerMixin, BaseEstimator):
-    def transform(self, X, annotations=None):
-        if annotations is None:
-            return [self.instance(data) for data in X]
-        else:
-            return [self.instance(data, annot) for data, annot in zip(X, annotations)]
-
-    def _more_tags(self):
-        return {"stateless": True, "requires_fit": False}
-
-
-def _get_pickable_method(method):
-    from bob.pipelines.utils import is_picklable
-
-    if not is_picklable(method):
-        logger.warning(
-            f"The method {method} is not picklable. Returning its unbounded method"
-        )
-        method = functools.partial(method.__func__, None)
-    return method
-
-
-class Preprocessor(CheckpointMixin, SampleMixin, _Preprocessor):
-    def __init__(
-        self,
-        callable,
-        transform_extra_arguments=(("annotations", "annotations"),),
-        **kwargs,
-    ):
-
-        # Input can be a functools.partial or an object
-        if isinstance(callable, functools.partial):
-            instance = callable()
-        else:
-            instance = callable
-
-        super().__init__(
-            callable=callable,
-            transform_extra_arguments=transform_extra_arguments,
-            load_func=instance.read_data,
-            save_func=instance.write_data,
-            **kwargs,
-        )
-
-
-def _split_X_by_y(X, y):
-    training_data = defaultdict(list)
-    for x1, y1 in zip(X, y):
-        training_data[y1].append(x1)
-    training_data = training_data.values()
-    return training_data
-
-
-class _Extractor(_NonPickableWrapper, TransformerMixin, BaseEstimator):
-    def transform(self, X, metadata=None):
-        if self.requires_metadata:
-            return [self.instance(data, metadata=m) for data, m in zip(X, metadata)]
-        else:
-            return [self.instance(data) for data in X]
-
-    def fit(self, X, y=None):
-        if not self.instance.requires_training:
-            return self
-
-        training_data = X
-        if self.instance.split_training_data_by_client:
-            training_data = _split_X_by_y(X, y)
-
-        self.instance.train(self, training_data, self.model_path)
-        return self
-
-    def _more_tags(self):
-        return {
-            "requires_fit": self.instance.requires_training,
-            "stateless": not self.instance.requires_training,
-        }
-
-
-class Extractor(CheckpointMixin, SampleMixin, _Extractor):
-    def __init__(self, callable, model_path=None, **kwargs):
-        # Input can be a functools.partial or an object
-        if isinstance(callable, functools.partial):
-            instance = callable()
-        else:
-            instance = callable
-
-        transform_extra_arguments = None
-        self.requires_metadata = False
-        if utils.is_argument_available("metadata", instance.__call__):
-            transform_extra_arguments = (("metadata", "metadata"),)
-            self.requires_metadata = True
-
-        fit_extra_arguments = None
-        if instance.requires_training and instance.split_training_data_by_client:
-            fit_extra_arguments = (("y", "subject"),)
-
-        super().__init__(
-            callable=callable,
-            transform_extra_arguments=transform_extra_arguments,
-            fit_extra_arguments=fit_extra_arguments,
-            load_func=instance.read_feature,
-            save_func=instance.write_feature,
-            model_path=model_path,
-            **kwargs,
-        )
-
-    def load_model(self):
-        self.instance.load(self.model_path)
-        return self
-
-    def save_model(self):
-        # we have already saved the model in .fit()
-        return self
-
-
-class _AlgorithmTransformer(TransformerMixin, BaseEstimator):
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.is_projector_loaded = False
-
-    def transform(self, X):
-        self._load_projector()
-        return [self.instance.project(feature) for feature in X]
-
-    def _load_projector(self):
-        """
-        Run :py:meth:`bob.bio.base.algorithm.Algorithm.load_projector` if necessary by
-        :py:class:`bob.bio.base.algorithm.Algorithm`
-        """
-        if self.instance.performs_projection:
-            if self.model_path is None:
-                raise ValueError(
-                    "Algorithm " + f"{self.instance} performs_projection. Hence, "
-                    "`model_path` needs to passed in `AlgorithmAsTransformer.__init__`"
-                )
-            else:
-                # Loading model
-                self.instance.load_projector(self.model_path)
-
-    def fit(self, X, y=None):
-        if not self.instance.requires_projector_training:
-            return self
-
-        training_data = X
-        if self.instance.split_training_features_by_client:
-            training_data = _split_X_by_y(X, y)
-
-        self.instance.train_projector(training_data, self.model_path)
-        return self
-
-    def _more_tags(self):
-        return {"requires_fit": self.instance.requires_projector_training}
-
-
-class AlgorithmAsTransformer(CheckpointMixin, SampleMixin, _AlgorithmTransformer):
-    """Class that wraps :py:class:`bob.bio.base.algorithm.Algoritm`
-
-    :py:method:`LegacyAlgorithmrMixin.fit` maps to :py:method:`bob.bio.base.algorithm.Algoritm.train_projector`
-
-    :py:method:`LegacyAlgorithmrMixin.transform` maps :py:method:`bob.bio.base.algorithm.Algoritm.project`
-
-    Example
-    -------
-
-        Wrapping LDA algorithm with functtools
-        >>> from bob.bio.base.pipelines.vanilla_biometrics.legacy import LegacyAlgorithmAsTransformer
-        >>> from bob.bio.base.algorithm import LDA
-        >>> import functools
-        >>> transformer = LegacyAlgorithmAsTransformer(functools.partial(LDA, use_pinv=True, pca_subspace_dimension=0.90))
-
-
-    Parameters
-    ----------
-    callable: callable
-        Calleble function that instantiates the bob.bio.base.algorithm.Algorithm
-
-    """
-
-    def __init__(self, callable, model_path, **kwargs):
-        instance = callable()
-
-        fit_extra_arguments = None
-        if (
-            instance.requires_projector_training
-            and instance.split_training_features_by_client
-        ):
-            fit_extra_arguments = (("y", "subject"),)
-
-        super().__init__(
-            callable=callable,
-            fit_extra_arguments=fit_extra_arguments,
-            load_func=_get_pickable_method(instance.read_feature),
-            save_func=_get_pickable_method(instance.write_feature),
-            model_path=model_path,
-            **kwargs,
-        )
-
-    def load_model(self):
-        self.instance.load_projector(self.model_path)
-        return self
-
-    def save_model(self):
-        # we have already saved the model in .fit()
-        return self
-
-
 class AlgorithmAsBioAlg(_NonPickableWrapper, BioAlgorithm):
     """Biometric Algorithm that handles legacy :py:class:`bob.bio.base.algorithm.Algorithm`
diff --git a/bob/bio/base/pipelines/vanilla_biometrics/pipeline.py b/bob/bio/base/pipelines/vanilla_biometrics/pipeline.py
index c5d0cbf487c708d2b8d6a8e5ef01cfecbc8fa897..65db1d3595e79571decf4b6accf4844320395280 100644
--- a/bob/bio/base/pipelines/vanilla_biometrics/pipeline.py
+++ b/bob/bio/base/pipelines/vanilla_biometrics/pipeline.py
@@ -148,6 +148,69 @@ class VanillaBiometrics(object):
 
         return scores
 
+class VanillaBiometricsZTNorm(object):
+    """
+    Vanilla Biometrics pipeline that runs ZT score normalization
+    """
+
+    def __init__(self, vanilla_pipeline):
+        self.vanilla_pipeline = vanilla_pipeline
+
+    def __call__(
+        self,
+        background_model_samples,
+        biometric_reference_samples,
+        probe_samples,
+        z_probe_samples,
+        t_biometric_reference_samples,
+        allow_scoring_with_all_biometric_references=False,
+    ):
+        logger.info(
+            f" >> Vanilla Biometrics: Training background model with pipeline {self.vanilla_pipeline.transformer}"
+        )
+
+        # Training background model (fit will return even if samples is ``None``,
+        # in which case we suppose the algorithm is not trainable in any way)
+        self.vanilla_pipeline.transformer = self.vanilla_pipeline.train_background_model(
+            background_model_samples
+        )
+
+        logger.info(
+            f" >> Creating biometric references with the biometric algorithm {self.vanilla_pipeline.biometric_algorithm}"
+        )
+
+        # Create biometric references
+        biometric_references = self.vanilla_pipeline.create_biometric_reference(
+            biometric_reference_samples
+        )
+
+        logger.info(
+            f" >> Computing scores with the biometric algorithm {self.vanilla_pipeline.biometric_algorithm}"
+        )
+
+        # Score all probes
+        scores = self.vanilla_pipeline.compute_scores(
+            probe_samples,
+            biometric_references,
+            allow_scoring_with_all_biometric_references,
+        )
+
+        # A list of SampleSets containing the Z statistics per
+        # biometric reference
+        zstatistics = self.compute_zstatistics(
+            z_probe_samples,
+            biometric_references,
+            allow_scoring_with_all_biometric_references,
+        )
+        z_norm_scores = self.znorm(scores, zstatistics)
+
+        # Create T biometric references
+        t_biometric_references = self.vanilla_pipeline.create_biometric_reference(
+            t_biometric_reference_samples
+        )
+
 
 def dask_vanilla_biometrics(pipeline, npartitions=None):
     """
     Given a :py:class:`VanillaBiometrics`, wraps :py:meth:`VanillaBiometrics.transformer` and
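Note that `compute_zstatistics` and `znorm` are not defined by this patch yet. For orientation, Z-normalization centers each raw score with statistics obtained by scoring the biometric reference against the Z-probe cohort; a minimal sketch of the statistic (helper name and signature are illustrative, not part of this patch):

```python
import numpy as np

def znorm(raw_scores, z_cohort_scores):
    # Z-norm: center the raw scores of one biometric reference with the
    # mean/std of that same reference scored against the Z-probe cohort.
    mu, sigma = np.mean(z_cohort_scores), np.std(z_cohort_scores)
    return [(score - mu) / sigma for score in raw_scores]
```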
diff --git a/bob/bio/base/test/test_transformers.py b/bob/bio/base/test/test_transformers.py
index 5fa1b098d1f4d92475d570fa098917bd6a5ee6ac..c78d5fab54757ab742d6a56428030074436f1514 100644
--- a/bob/bio/base/test/test_transformers.py
+++ b/bob/bio/base/test/test_transformers.py
@@ -1,68 +1,285 @@
 #!/usr/bin/env python
 # vim: set fileencoding=utf-8 :
-# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
-from bob.pipelines.sample import Sample, SampleSet, DelayedSample
-import os
-import numpy
+from bob.bio.base.preprocessor import Preprocessor
+from bob.bio.base.extractor import Extractor
+from bob.bio.base.algorithm import Algorithm
+from bob.bio.base.transformers import (
+    PreprocessorTransformer,
+    ExtractorTransformer,
+    AlgorithmTransformer,
+)
+import bob.pipelines as mario
+import numpy as np
 import tempfile
-from sklearn.utils.validation import check_is_fitted
+import os
+import bob.io.base
+from bob.bio.base.wrappers import (
+    wrap_preprocessor,
+    wrap_extractor,
+    wrap_algorithm,
+    wrap_transform_bob,
+)
+from sklearn.pipeline import make_pipeline
+
+
+class _FakePreprocessor(Preprocessor):
+    def __call__(self, data, annotations=None):
+        return data + annotations
+
+
+class _FakeExtractor(Extractor):
+    def __call__(self, data, metadata=None):
+        return data.flatten()
+
+
+class _FakeExtractorFittable(Extractor):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.requires_training = True
+        self.model = None
+
+    def __call__(self, data, metadata=None):
+        return data @ self.model
+
+    def train(self, training_data, extractor_file):
+        self.model = training_data
+        bob.io.base.save(self.model, extractor_file)
+
+
+class _FakeAlgorithm(Algorithm):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.requires_training = True
+        self.split_training_features_by_client = True
+        self.model = None
+
+    def project(self, data, metadata=None):
+        return data + self.model
+
+    def train_projector(self, training_features, projector_file):
+        self.model = np.sum(np.vstack(training_features), axis=0)
+        bob.io.base.save(self.model, projector_file)
+
+    def load_projector(self, projector_file):
+        self.model = bob.io.base.load(projector_file)
+
+
+def generate_samples(n_subjects, n_samples_per_subject, shape=(2, 2), annotations=1):
+    """
+    Simple sample generator that generates a certain number of samples per
+    subject, whose data is np.zeros + subject index
+    """
+
+    samples = []
+    for i in range(n_subjects):
+        data = np.zeros(shape) + i
+        for j in range(n_samples_per_subject):
+            samples += [
+                mario.Sample(
+                    data,
+                    subject=str(i),
+                    key=str(i * n_subjects + j),
+                    annotations=annotations,
+                )
+            ]
+    return samples
+
+
+def assert_sample(transformed_sample, oracle):
+    return np.alltrue(
+        [np.allclose(ts.data, o) for ts, o in zip(transformed_sample, oracle)]
+    )
+
+
+def assert_checkpoints(transformed_sample, dir_name):
+    return np.alltrue(
+        [
+            os.path.exists(os.path.join(dir_name, ts.key + ".h5"))
+            for ts in transformed_sample
+        ]
+    )
+
+
+def test_preprocessor():
+
+    preprocessor = _FakePreprocessor()
+    preprocessor_transformer = PreprocessorTransformer(preprocessor)
+
+    # Testing sample
+    transform_extra_arguments = [("annotations", "annotations")]
+    sample_transformer = mario.SampleWrapper(
+        preprocessor_transformer, transform_extra_arguments
+    )
+
+    data = np.zeros((2, 2))
+    oracle = [np.ones((2, 2))]
+    annotations = 1
+    sample = [mario.Sample(data, key="1", annotations=annotations)]
+    transformed_sample = sample_transformer.transform(sample)
+
+    assert assert_sample(transformed_sample, oracle)
+
+    # Testing checkpoint
+    with tempfile.TemporaryDirectory() as dir_name:
+        checkpointing_transformer = mario.CheckpointWrapper(
+            sample_transformer,
+            features_dir=dir_name,
+            load_func=preprocessor.read_data,
+            save_func=preprocessor.write_data,
+        )
+        transformed_sample = checkpointing_transformer.transform(sample)
+
+        assert assert_sample(transformed_sample, oracle)
+        assert assert_checkpoints(transformed_sample, dir_name)
+
+
+def test_extractor():
+
+    extractor = _FakeExtractor()
+    extractor_transformer = ExtractorTransformer(extractor)
+
+    # Testing sample
+    sample_transformer = mario.SampleWrapper(extractor_transformer)
+
+    data = np.zeros((2, 2))
+    oracle = [np.zeros((1, 4))]
+    sample = [mario.Sample(data, key="1")]
+    transformed_sample = sample_transformer.transform(sample)
+
+    assert assert_sample(transformed_sample, oracle)
+
+    # Testing checkpoint
+    with tempfile.TemporaryDirectory() as dir_name:
+        checkpointing_transformer = mario.CheckpointWrapper(
+            sample_transformer,
+            features_dir=dir_name,
+            load_func=extractor.read_feature,
+            save_func=extractor.write_feature,
+        )
+        transformed_sample = checkpointing_transformer.transform(sample)
+
+        assert assert_sample(transformed_sample, oracle)
+        assert assert_checkpoints(transformed_sample, dir_name)
+
+
+def test_extractor_fittable():
+
+    with tempfile.TemporaryDirectory() as dir_name:
+
+        extractor_file = os.path.join(dir_name, "Extractor.hdf5")
+        extractor = _FakeExtractorFittable()
+        extractor_transformer = ExtractorTransformer(
+            extractor, model_path=extractor_file
+        )
+
+        # Testing sample
+        sample_transformer = mario.SampleWrapper(extractor_transformer)
+
+        # Fitting
+        training_data = np.arange(4).reshape(2, 2)
+        training_samples = [mario.Sample(training_data, key="1")]
+        sample_transformer = sample_transformer.fit(training_samples)
+
+        test_data = [np.zeros((2, 2)), np.ones((2, 2))]
+        oracle = [np.zeros((2, 2)), np.ones((2, 2)) @ training_data]
+        test_sample = [mario.Sample(d, key=str(i)) for i, d in enumerate(test_data)]
+
+        transformed_sample = sample_transformer.transform(test_sample)
+        assert assert_sample(transformed_sample, oracle)
+
+        # Testing checkpoint
+        checkpointing_transformer = mario.CheckpointWrapper(
+            sample_transformer,
+            features_dir=dir_name,
+            load_func=extractor.read_feature,
+            save_func=extractor.write_feature,
+        )
+        transformed_sample = checkpointing_transformer.transform(test_sample)
+        assert assert_sample(transformed_sample, oracle)
+        assert assert_checkpoints(transformed_sample, dir_name)
-from bob.pipelines.transformers import Linearize, SampleLinearize, CheckpointSampleLinearize
-def test_linearize_processor():
-    ## Test the transformer only
-    transformer = Linearize()
-    X = numpy.zeros(shape=(10,10))
-    X_tr = transformer.transform(X)
-    assert X_tr.shape == (100,)
+
+
+def test_algorithm():
+    with tempfile.TemporaryDirectory() as dir_name:
-    ## Test wrapped in to a Sample
-    sample = Sample(X, key="1")
-    transformer = SampleLinearize()
-    X_tr = transformer.transform([sample])
-    assert X_tr[0].data.shape == (100,)
+        projector_file = os.path.join(dir_name, "Projector.hdf5")
+        projector_pkl = os.path.join(dir_name, "Projector.pkl")  # Testing pickling
-    ## Test checkpoint
-    with tempfile.TemporaryDirectory() as d:
-        transformer = CheckpointSampleLinearize(features_dir=d)
-        X_tr = transformer.transform([sample])
-        assert X_tr[0].data.shape == (100,)
-        assert os.path.exists(os.path.join(d, "1.h5"))
+        algorithm = _FakeAlgorithm()
+        algorithm_transformer = AlgorithmTransformer(
+            algorithm, projector_file=projector_file
+        )
+
+        # Testing sample
+        fit_extra_arguments = [("y", "subject")]
+        sample_transformer = mario.SampleWrapper(
+            algorithm_transformer, fit_extra_arguments=fit_extra_arguments
+        )
-from bob.pipelines.transformers import SamplePCA, CheckpointSamplePCA
-def test_pca_processor():
+        n_subjects = 2
+        n_samples_per_subject = 2
+        shape = (2, 2)
+        training_samples = generate_samples(
+            n_subjects, n_samples_per_subject, shape=shape
+        )
+        sample_transformer = sample_transformer.fit(training_samples)
-    ## Test wrapped in to a Sample
-    X = numpy.random.rand(100,10)
-    samples = [Sample(data, key=str(i)) for i, data in enumerate(X)]
+        oracle = np.zeros(shape) + n_subjects
+        test_sample = generate_samples(1, 1)
+        transformed_sample = sample_transformer.transform(test_sample)
+        assert assert_sample(transformed_sample, [oracle])
+        assert os.path.exists(projector_file)
-    # fit
-    n_components = 2
-    estimator = SamplePCA(n_components=n_components)
-    estimator = estimator.fit(samples)
+        # Testing checkpoint
+        checkpointing_transformer = mario.CheckpointWrapper(
+            sample_transformer,
+            features_dir=dir_name,
+            load_func=algorithm.read_feature,
+            save_func=algorithm.write_feature,
+            model_path=projector_pkl,
+        )
+        # Fitting again to assert if it loads again
+        checkpointing_transformer = checkpointing_transformer.fit(training_samples)
+        transformed_sample = checkpointing_transformer.transform(test_sample)
-    # https://scikit-learn.org/stable/modules/generated/sklearn.utils.validation.check_is_fitted.html
-    assert check_is_fitted(estimator, "n_components_") is None
+        assert assert_sample(transformed_sample, oracle)
+
+        # Transforming again, now reading from the checkpoints
+        transformed_sample = checkpointing_transformer.transform(test_sample)
+        assert assert_checkpoints(transformed_sample, dir_name)
+        assert os.path.exists(projector_pkl)
-    # transform
-    samples_tr = estimator.transform(samples)
-    assert samples_tr[0].data.shape == (n_components,)
+
+
+def test_wrap_bob_pipeline():
-    ## Test Checkpoining
-    with tempfile.TemporaryDirectory() as d:
-        model_path = os.path.join(d, "model.pkl")
-        estimator = CheckpointSamplePCA(n_components=n_components, features_dir=d, model_path=model_path)
+    def run_pipeline(with_dask):
+        with tempfile.TemporaryDirectory() as dir_name:
-        # fit
-        estimator = estimator.fit(samples)
-        assert check_is_fitted(estimator, "n_components_") is None
-        assert os.path.exists(model_path)
+            pipeline = make_pipeline(
+                wrap_transform_bob(
+                    _FakePreprocessor(),
+                    dir_name,
+                    transform_extra_arguments=(("annotations", "annotations"),),
+                ),
+                wrap_transform_bob(_FakeExtractor(), dir_name,),
+                wrap_transform_bob(
+                    _FakeAlgorithm(), dir_name, fit_extra_arguments=(("y", "subject"),)
+                ),
+            )
+            oracle = [7.0, 7.0, 7.0, 7.0]
+            training_samples = generate_samples(n_subjects=2, n_samples_per_subject=2)
+            test_samples = generate_samples(n_subjects=1, n_samples_per_subject=1)
+            if with_dask:
+                pipeline = mario.wrap(["dask"], pipeline)
+                transformed_samples = (
+                    pipeline.fit(training_samples).transform(test_samples).compute()
+                )
+            else:
+                transformed_samples = pipeline.fit(training_samples).transform(
+                    test_samples
+                )
+            assert assert_sample(transformed_samples, oracle)
-    # transform
-    samples_tr = estimator.transform(samples)
-    assert samples_tr[0].data.shape == (n_components,)
-    assert os.path.exists(os.path.join(d, samples_tr[0].key+".h5"))
+
+    run_pipeline(False)
+    run_pipeline(True)
diff --git a/bob/bio/base/transformers/__init__.py b/bob/bio/base/transformers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4893224a141f179011ceea7b35e2df51af3d94c4
--- /dev/null
+++ b/bob/bio/base/transformers/__init__.py
@@ -0,0 +1,18 @@
+# see https://docs.python.org/3/library/pkgutil.html
+from pkgutil import extend_path
+
+__path__ = extend_path(__path__, __name__)
+
+from collections import defaultdict
+def split_X_by_y(X, y):
+    training_data = defaultdict(list)
+    for x1, y1 in zip(X, y):
+        training_data[y1].append(x1)
+    training_data = list(training_data.values())
+    return training_data
+
+
+
+from .preprocessor import PreprocessorTransformer
+from .extractor import ExtractorTransformer
+from .algorithm import AlgorithmTransformer
diff --git a/bob/bio/base/transformers/algorithm.py b/bob/bio/base/transformers/algorithm.py
new file mode 100644
index 0000000000000000000000000000000000000000..1fe164d0203486dc6a9e40159fd6df72b55ecdac
--- /dev/null
+++ b/bob/bio/base/transformers/algorithm.py
@@ -0,0 +1,81 @@
+#!/usr/bin/env python
+# vim: set fileencoding=utf-8 :
+
+from sklearn.base import TransformerMixin, BaseEstimator
+from bob.bio.base.algorithm import Algorithm
+from bob.pipelines.utils import is_picklable
+from . import split_X_by_y
+import os
+
+
+class AlgorithmTransformer(TransformerMixin, BaseEstimator):
+    """Class that wraps :any:`bob.bio.base.algorithm.Algorithm`
+
+    :any:`AlgorithmTransformer.fit` maps to :any:`bob.bio.base.algorithm.Algorithm.train_projector`
+
+    :any:`AlgorithmTransformer.transform` maps to :any:`bob.bio.base.algorithm.Algorithm.project`
+
+    Example
+    -------
+
+        Wrapping an LDA algorithm
+        >>> from bob.bio.base.transformers import AlgorithmTransformer
+        >>> from bob.bio.base.algorithm import LDA
+        >>> transformer = AlgorithmTransformer(LDA(use_pinv=True, pca_subspace_dimension=0.90), projector_file="projector.hdf5")
+
+
+    Parameters
+    ----------
+    callable: ``collections.Callable``
+        Instance of :any:`bob.bio.base.algorithm.Algorithm`
+
+    projector_file: ``str``
+        Path to save the projector, needed in case ``callable.requires_training`` is True
+
+    """
+
+    def __init__(
+        self, callable, projector_file=None, **kwargs,
+    ):
+
+        if not isinstance(callable, Algorithm):
+            raise ValueError(
+                "`callable` should be an instance of `bob.bio.base.algorithm.Algorithm`"
+            )
+
+        if callable.requires_training and (
+            projector_file is None or projector_file == ""
+        ):
+            raise ValueError(
+                f"`projector_file` needs to be set if algorithm {callable} requires training"
+            )
+
+        if not is_picklable(callable):
+            raise ValueError(f"{callable} needs to be picklable")
+
+        self.callable = callable
+        self.projector_file = projector_file
+        super().__init__(**kwargs)
+
+    def fit(self, X, y=None):
+        if not self.callable.requires_training:
+            return self
+        training_data = X
+        if self.callable.split_training_features_by_client:
+            training_data = split_X_by_y(X, y)
+
+        os.makedirs(os.path.dirname(self.projector_file), exist_ok=True)
+        self.callable.train_projector(training_data, self.projector_file)
+        return self
+
+    def transform(self, X, metadata=None):
+        if metadata is None:
+            return [self.callable.project(data) for data in X]
+        else:
+            return [
+                self.callable.project(data, metadata)
+                for data, metadata in zip(X, metadata)
+            ]
+
+    def _more_tags(self):
+        if self.callable.requires_training:
+            return {"stateless": False, "requires_fit": True}
+        else:
+            return {"stateless": True, "requires_fit": False}
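A short usage sketch of `AlgorithmTransformer`, mirroring `test_algorithm` above (`_FakeAlgorithm` is the test stub; any `Algorithm` with projector training works the same way, and the projector path here is illustrative):

```python
import bob.pipelines as mario
from bob.bio.base.transformers import AlgorithmTransformer

algorithm = _FakeAlgorithm()
transformer = AlgorithmTransformer(algorithm, projector_file="/tmp/Projector.hdf5")

# The algorithm splits its training features by client, so the sample layer
# must forward each sample's ``subject`` attribute as ``y`` during fit()
sample_transformer = mario.SampleWrapper(
    transformer, fit_extra_arguments=(("y", "subject"),)
)
```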
diff --git a/bob/bio/base/transformers/extractor.py b/bob/bio/base/transformers/extractor.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2ddcb327d4efd24984c50070824bf8008737a0d
--- /dev/null
+++ b/bob/bio/base/transformers/extractor.py
@@ -0,0 +1,64 @@
+#!/usr/bin/env python
+# vim: set fileencoding=utf-8 :
+
+from sklearn.base import TransformerMixin, BaseEstimator
+from bob.bio.base.extractor import Extractor
+from . import split_X_by_y
+
+
+class ExtractorTransformer(TransformerMixin, BaseEstimator):
+    """
+    Scikit-learn transformer for :any:`bob.bio.base.extractor.Extractor`.
+
+    Parameters
+    ----------
+    callable: ``collections.Callable``
+        Instance of :any:`bob.bio.base.extractor.Extractor`
+
+    model_path: ``str``
+        Model path, needed in case :any:`bob.bio.base.extractor.Extractor.requires_training` is True
+
+    """
+
+    def __init__(
+        self, callable, model_path=None, **kwargs,
+    ):
+
+        if not isinstance(callable, Extractor):
+            raise ValueError(
+                "`callable` should be an instance of `bob.bio.base.extractor.Extractor`"
+            )
+
+        if callable.requires_training and (model_path is None or model_path == ""):
+            raise ValueError(
+                f"`model_path` needs to be set if extractor {callable} requires training"
+            )
+
+        self.callable = callable
+        self.model_path = model_path
+        super().__init__(**kwargs)
+
+    def fit(self, X, y=None):
+        if not self.callable.requires_training:
+            return self
+
+        training_data = X
+        if self.callable.split_training_data_by_client:
+            training_data = split_X_by_y(X, y)
+
+        self.callable.train(training_data, self.model_path)
+        return self
+
+    def transform(self, X, metadata=None):
+        if metadata is None:
+            return [self.callable(data) for data in X]
+        else:
+            return [
+                self.callable(data, metadata) for data, metadata in zip(X, metadata)
+            ]
+
+    def _more_tags(self):
+        if self.callable.requires_training:
+            return {"stateless": False, "requires_fit": True}
+        else:
+            return {"stateless": True, "requires_fit": False}
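Analogously for `ExtractorTransformer`; a sketch with the stateless `_FakeExtractor` test stub, which needs no `model_path`:

```python
import numpy as np
import bob.pipelines as mario
from bob.bio.base.transformers import ExtractorTransformer

sample_transformer = mario.SampleWrapper(ExtractorTransformer(_FakeExtractor()))
samples = [mario.Sample(np.zeros((2, 2)), key="1")]
transformed = sample_transformer.transform(samples)  # each .data is flattened to shape (4,)
```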
diff --git a/bob/bio/base/transformers/preprocessor.py b/bob/bio/base/transformers/preprocessor.py
new file mode 100644
index 0000000000000000000000000000000000000000..1cddda6713117cceb5cb0035b77dde624c0145db
--- /dev/null
+++ b/bob/bio/base/transformers/preprocessor.py
@@ -0,0 +1,39 @@
+#!/usr/bin/env python
+# vim: set fileencoding=utf-8 :
+
+from sklearn.base import TransformerMixin, BaseEstimator
+from bob.bio.base.preprocessor import Preprocessor
+
+
+class PreprocessorTransformer(TransformerMixin, BaseEstimator):
+    """
+    Scikit-learn transformer for :any:`bob.bio.base.preprocessor.Preprocessor`.
+
+    Parameters
+    ----------
+    callable: ``collections.Callable``
+        Instance of :any:`bob.bio.base.preprocessor.Preprocessor`
+
+    """
+
+    def __init__(
+        self, callable, **kwargs,
+    ):
+
+        if not isinstance(callable, Preprocessor):
+            raise ValueError(
+                "`callable` should be an instance of `bob.bio.base.preprocessor.Preprocessor`"
+            )
+
+        self.callable = callable
+        super().__init__(**kwargs)
+
+    def transform(self, X, annotations=None):
+        if annotations is None:
+            return [self.callable(data) for data in X]
+        else:
+            return [self.callable(data, annot) for data, annot in zip(X, annotations)]
+
+    def _more_tags(self):
+        return {"stateless": True, "requires_fit": False}
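And for `PreprocessorTransformer`, where the per-sample annotations have to be mapped into the `annotations` argument of `transform`; a sketch mirroring `test_preprocessor`:

```python
import numpy as np
import bob.pipelines as mario
from bob.bio.base.transformers import PreprocessorTransformer

sample_transformer = mario.SampleWrapper(
    PreprocessorTransformer(_FakePreprocessor()),
    transform_extra_arguments=(("annotations", "annotations"),),
)
# _FakePreprocessor adds the annotation (1) to the data, so zeros become ones
samples = [mario.Sample(np.zeros((2, 2)), key="1", annotations=1)]
assert np.allclose(sample_transformer.transform(samples)[0].data, 1.0)
```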
diff --git a/bob/bio/base/wrappers.py b/bob/bio/base/wrappers.py
new file mode 100644
index 0000000000000000000000000000000000000000..2142c23146085c88194b0f9c65653a5d007a147b
--- /dev/null
+++ b/bob/bio/base/wrappers.py
@@ -0,0 +1,208 @@
+#!/usr/bin/env python
+# vim: set fileencoding=utf-8 :
+
+from bob.bio.base.transformers import (
+    PreprocessorTransformer,
+    ExtractorTransformer,
+    AlgorithmTransformer,
+)
+from bob.bio.base.preprocessor import Preprocessor
+from bob.bio.base.extractor import Extractor
+from bob.bio.base.algorithm import Algorithm
+import bob.pipelines as mario
+import os
+
+
+def wrap_transform_bob(
+    bob_object, dir_name, fit_extra_arguments=None, transform_extra_arguments=None
+):
+    """
+    Wraps either a :any:`bob.bio.base.preprocessor.Preprocessor`, :any:`bob.bio.base.extractor.Extractor`
+    or :any:`bob.bio.base.algorithm.Algorithm` with :any:`sklearn.base.TransformerMixin`
+    and :any:`bob.pipelines.wrappers.CheckpointWrapper` and :any:`bob.pipelines.wrappers.SampleWrapper`
+
+    Parameters
+    ----------
+    bob_object: object
+        Instance of :any:`bob.bio.base.preprocessor.Preprocessor`, :any:`bob.bio.base.extractor.Extractor`
+        or :any:`bob.bio.base.algorithm.Algorithm`
+
+    dir_name: str
+        Directory name for the checkpoints
+
+    fit_extra_arguments: [tuple]
+        Same behavior as the ``fit_extra_arguments`` of :any:`bob.pipelines.wrappers.SampleWrapper`
+
+    transform_extra_arguments: [tuple]
+        Same behavior as the ``transform_extra_arguments`` of :any:`bob.pipelines.wrappers.SampleWrapper`
+
+    """
+
+    if isinstance(bob_object, Preprocessor):
+        preprocessor_transformer = PreprocessorTransformer(bob_object)
+        return wrap_preprocessor(
+            preprocessor_transformer,
+            features_dir=os.path.join(dir_name, "preprocessor"),
+            transform_extra_arguments=transform_extra_arguments,
+        )
+    elif isinstance(bob_object, Extractor):
+        extractor_transformer = ExtractorTransformer(bob_object)
+        path = os.path.join(dir_name, "extractor")
+        return wrap_extractor(
+            extractor_transformer,
+            features_dir=path,
+            model_path=os.path.join(path, "extractor.pkl"),
+            transform_extra_arguments=transform_extra_arguments,
+            fit_extra_arguments=fit_extra_arguments,
+        )
+    elif isinstance(bob_object, Algorithm):
+        path = os.path.join(dir_name, "algorithm")
+        algorithm_transformer = AlgorithmTransformer(
+            bob_object, projector_file=os.path.join(path, "Projector.hdf5")
+        )
+        return wrap_algorithm(
+            algorithm_transformer,
+            features_dir=path,
+            model_path=os.path.join(path, "algorithm.pkl"),
+            transform_extra_arguments=transform_extra_arguments,
+            fit_extra_arguments=fit_extra_arguments,
+        )
+    else:
+        raise ValueError(
+            "`bob_object` should be an instance of `Preprocessor`, `Extractor` or `Algorithm`"
+        )
+
+
+def wrap_preprocessor(
+    preprocessor_transformer, features_dir=None, transform_extra_arguments=None,
+):
+    """
+    Wraps a :any:`bob.bio.base.transformers.PreprocessorTransformer` with
+    :any:`bob.pipelines.wrappers.CheckpointWrapper` and :any:`bob.pipelines.wrappers.SampleWrapper`
+
+    Parameters
+    ----------
+    preprocessor_transformer: :any:`bob.bio.base.transformers.PreprocessorTransformer`
+        Instance of :any:`bob.bio.base.transformers.PreprocessorTransformer` to be wrapped
+
+    features_dir: str
+        Features directory to be checkpointed
+
+    transform_extra_arguments: [tuple]
+        Same behavior as the ``transform_extra_arguments`` of :any:`bob.pipelines.wrappers.SampleWrapper`
+
+    """
+
+    if not isinstance(preprocessor_transformer, PreprocessorTransformer):
+        raise ValueError(
+            f"Expected an instance of PreprocessorTransformer, not {preprocessor_transformer}"
+        )
+
+    return mario.wrap(
+        ["sample", "checkpoint"],
+        preprocessor_transformer,
+        load_func=preprocessor_transformer.callable.read_data,
+        save_func=preprocessor_transformer.callable.write_data,
+        features_dir=features_dir,
+        transform_extra_arguments=transform_extra_arguments,
+    )
+
+
+def wrap_extractor(
+    extractor_transformer,
+    fit_extra_arguments=None,
+    transform_extra_arguments=None,
+    features_dir=None,
+    model_path=None,
+):
+    """
+    Wraps a :any:`bob.bio.base.transformers.ExtractorTransformer` with
+    :any:`bob.pipelines.wrappers.CheckpointWrapper` and :any:`bob.pipelines.wrappers.SampleWrapper`
+
+    Parameters
+    ----------
+    extractor_transformer: :any:`bob.bio.base.transformers.ExtractorTransformer`
+        Instance of :any:`bob.bio.base.transformers.ExtractorTransformer` to be wrapped
+
+    features_dir: str
+        Features directory to be checkpointed
+
+    model_path: str
+        Path to checkpoint the model
+
+    fit_extra_arguments: [tuple]
+        Same behavior as the ``fit_extra_arguments`` of :any:`bob.pipelines.wrappers.SampleWrapper`
+
+    transform_extra_arguments: [tuple]
+        Same behavior as the ``transform_extra_arguments`` of :any:`bob.pipelines.wrappers.SampleWrapper`
+
+    """
+
+    if not isinstance(extractor_transformer, ExtractorTransformer):
+        raise ValueError(
+            f"Expected an instance of ExtractorTransformer, not {extractor_transformer}"
+        )
+
+    return mario.wrap(
+        ["sample", "checkpoint"],
+        extractor_transformer,
+        load_func=extractor_transformer.callable.read_feature,
+        save_func=extractor_transformer.callable.write_feature,
+        model_path=model_path,
+        features_dir=features_dir,
+        transform_extra_arguments=transform_extra_arguments,
+        fit_extra_arguments=fit_extra_arguments,
+    )
+
+
+def wrap_algorithm(
+    algorithm_transformer,
+    fit_extra_arguments=None,
+    transform_extra_arguments=None,
+    features_dir=None,
+    model_path=None,
+):
+    """
+    Wraps a :any:`bob.bio.base.transformers.AlgorithmTransformer` with
+    :any:`bob.pipelines.wrappers.CheckpointWrapper` and :any:`bob.pipelines.wrappers.SampleWrapper`
+
+    Parameters
+    ----------
+    algorithm_transformer: :any:`bob.bio.base.transformers.AlgorithmTransformer`
+        Instance of :any:`bob.bio.base.transformers.AlgorithmTransformer` to be wrapped
+
+    features_dir: str
+        Features directory to be checkpointed
+
+    model_path: str
+        Path to checkpoint the model
+
+    fit_extra_arguments: [tuple]
+        Same behavior as the ``fit_extra_arguments`` of :any:`bob.pipelines.wrappers.SampleWrapper`
+
+    transform_extra_arguments: [tuple]
+        Same behavior as the ``transform_extra_arguments`` of :any:`bob.pipelines.wrappers.SampleWrapper`
+
+    """
+
+    if not isinstance(algorithm_transformer, AlgorithmTransformer):
+        raise ValueError(
+            f"Expected an instance of AlgorithmTransformer, not {algorithm_transformer}"
+        )
+
+    return mario.wrap(
+        ["sample", "checkpoint"],
+        algorithm_transformer,
+        load_func=algorithm_transformer.callable.read_feature,
+        save_func=algorithm_transformer.callable.write_feature,
+        model_path=model_path,
+        features_dir=features_dir,
+        transform_extra_arguments=transform_extra_arguments,
+        fit_extra_arguments=fit_extra_arguments,
+    )
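With these helpers, assembling a checkpointed preprocess → extract → project pipeline reduces to a few lines; a sketch following `test_wrap_bob_pipeline` above (the `_Fake*` classes are the test stubs, and any real `Preprocessor`/`Extractor`/`Algorithm` instance works the same way; with the optional dask wrapping, results are computed lazily and need a final `.compute()`):

```python
import tempfile
import bob.pipelines as mario
from sklearn.pipeline import make_pipeline
from bob.bio.base.wrappers import wrap_transform_bob

with tempfile.TemporaryDirectory() as dir_name:
    pipeline = make_pipeline(
        wrap_transform_bob(
            _FakePreprocessor(),
            dir_name,
            transform_extra_arguments=(("annotations", "annotations"),),
        ),
        wrap_transform_bob(_FakeExtractor(), dir_name),
        wrap_transform_bob(
            _FakeAlgorithm(), dir_name, fit_extra_arguments=(("y", "subject"),)
        ),
    )
    # Optional: run the same pipeline lazily on a dask cluster
    pipeline = mario.wrap(["dask"], pipeline)
```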