diff --git a/bob/bio/base/test/test_transformers.py b/bob/bio/base/test/test_transformers.py
index e78049d736482e56f0afdd3e000b35853b6810d4..0bd43dcab8ab5e655b346ba893a3b845553bffa3 100644
--- a/bob/bio/base/test/test_transformers.py
+++ b/bob/bio/base/test/test_transformers.py
@@ -16,10 +16,12 @@ import tempfile
 import os
 import bob.io.base
 from bob.bio.base.wrappers import (
-    wrap_preprocessor,
-    wrap_extractor,
-    wrap_algorithm,
-    wrap_bob_legacy,
+    wrap_checkpoint_preprocessor,
+    wrap_checkpoint_extractor,
+    wrap_checkpoint_algorithm,
+    wrap_sample_preprocessor,
+    wrap_sample_extractor,
+    wrap_sample_algorithm,
 )
 from sklearn.pipeline import make_pipeline
@@ -30,7 +32,7 @@ class FakePreprocesor(Preprocessor):
 
 
 class FakeExtractor(Extractor):
-    def __call__(self, data, metadata=None):
+    def __call__(self, data):
         return data.flatten()
 
 
@@ -56,7 +58,7 @@ class FakeAlgorithm(Algorithm):
         self.split_training_features_by_client = True
         self.model = None
 
-    def project(self, data, metadata=None):
+    def project(self, data):
         return data + self.model
 
     def train_projector(self, training_features, projector_file):
@@ -259,25 +261,30 @@ def test_algorithm():
 
 
 def test_wrap_bob_pipeline():
-    def run_pipeline(with_dask):
+    def run_pipeline(with_dask, with_checkpoint):
         with tempfile.TemporaryDirectory() as dir_name:
+            if with_checkpoint:
+                pipeline = make_pipeline(
+                    wrap_checkpoint_preprocessor(FakePreprocesor(), dir_name,),
+                    wrap_checkpoint_extractor(FakeExtractor(), dir_name,),
+                    wrap_checkpoint_algorithm(FakeAlgorithm(), dir_name),
+                )
+            else:
+                pipeline = make_pipeline(
+                    wrap_sample_preprocessor(FakePreprocesor()),
+                    wrap_sample_extractor(FakeExtractor(), model_path=dir_name),
+                    wrap_sample_algorithm(FakeAlgorithm(), dir_name),
+                )
 
-            pipeline = make_pipeline(
-                wrap_bob_legacy(
-                    FakePreprocesor(),
-                    dir_name,
-                    transform_extra_arguments=(("annotations", "annotations"),),
-                ),
-                wrap_bob_legacy(FakeExtractor(), dir_name,),
-                wrap_bob_legacy(FakeAlgorithm(), dir_name),
-            )
             oracle = [7.0, 7.0, 7.0, 7.0]
             training_samples = generate_samples(n_subjects=2, n_samples_per_subject=2)
             test_samples = generate_samples(n_subjects=1, n_samples_per_subject=1)
             if with_dask:
                 pipeline = mario.wrap(["dask"], pipeline)
                 transformed_samples = (
-                    pipeline.fit(training_samples).transform(test_samples).compute(scheduler="single-threaded")
+                    pipeline.fit(training_samples)
+                    .transform(test_samples)
+                    .compute(scheduler="single-threaded")
                 )
             else:
                 transformed_samples = pipeline.fit(training_samples).transform(
@@ -285,5 +292,7 @@ def test_wrap_bob_pipeline():
                 )
             assert assert_sample(transformed_samples, oracle)
 
-    run_pipeline(False)
-    run_pipeline(True)
+    run_pipeline(False, False)
+    run_pipeline(False, True)
+    run_pipeline(True, False)
+    run_pipeline(True, True)
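Note: the `wrap_sample_*` / `wrap_checkpoint_*` split exercised above is the crux of this patch — sample wrappers only make the legacy object consume and produce `bob.pipelines` Samples, while checkpoint wrappers additionally cache data on disk. A minimal sketch of the distinction (not part of the patch; the fake and `dir_name` are the test fixtures above):

```python
import tempfile

from bob.bio.base.test.test_transformers import FakeExtractor
from bob.bio.base.wrappers import wrap_checkpoint_extractor, wrap_sample_extractor

with tempfile.TemporaryDirectory() as dir_name:
    # Sample-only: no feature I/O; the wrapped extractor just works on Samples.
    sample_extractor = wrap_sample_extractor(FakeExtractor(), model_path=dir_name)

    # Sample + checkpoint: features are additionally written under dir_name
    # and reloaded on the next run instead of being recomputed.
    checkpoint_extractor = wrap_checkpoint_extractor(FakeExtractor(), dir_name)
```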
diff --git a/bob/bio/base/test/test_vanilla_biometrics.py b/bob/bio/base/test/test_vanilla_biometrics.py
index ff7d7c4da706b589a3873764b8a2db512a6ad9b0..ff525ab9e06c428852da76ab889af0b3a3cbff14 100644
--- a/bob/bio/base/test/test_vanilla_biometrics.py
+++ b/bob/bio/base/test/test_vanilla_biometrics.py
@@ -8,7 +8,11 @@ import numpy as np
 import tempfile
 from sklearn.pipeline import make_pipeline
 from bob.bio.base.wrappers import wrap_bob_legacy
-from bob.bio.base.test.test_transformers import FakePreprocesor, FakeExtractor, FakeAlgorithm
+from bob.bio.base.test.test_transformers import (
+    FakePreprocesor,
+    FakeExtractor,
+    FakeAlgorithm,
+)
 from bob.bio.base.pipelines.vanilla_biometrics import (
     Distance,
     VanillaBiometricsPipeline,
@@ -16,7 +20,7 @@ from bob.bio.base.pipelines.vanilla_biometrics import (
     dask_vanilla_biometrics,
     FourColumnsScoreWriter,
     CSVScoreWriter,
-    BioAlgorithmLegacy
+    BioAlgorithmLegacy,
 )
 import bob.pipelines as mario
@@ -24,6 +28,7 @@ import uuid
 import shutil
 import itertools
 
+
 class DummyDatabase:
     def __init__(self, delayed=False, n_references=10, n_probes=10, dim=10, one_d=True):
         self.delayed = delayed
@@ -36,13 +41,23 @@ class DummyDatabase:
 
     def _create_random_1dsamples(self, n_samples, offset, dim):
         return [
-            Sample(np.random.rand(dim), key=str(uuid.uuid4()), annotations=1, subject=str(i))
+            Sample(
+                np.random.rand(dim),
+                key=str(uuid.uuid4()),
+                annotations=1,
+                subject=str(i),
+            )
             for i in range(offset, offset + n_samples)
         ]
 
     def _create_random_2dsamples(self, n_samples, offset, dim):
         return [
-            Sample(np.random.rand(dim, dim), key=str(uuid.uuid4()), annotations=1, subject=str(i))
+            Sample(
+                np.random.rand(dim, dim),
+                key=str(uuid.uuid4()),
+                annotations=1,
+                subject=str(i),
+            )
             for i in range(offset, offset + n_samples)
         ]
 
@@ -74,7 +89,7 @@ class DummyDatabase:
         return sample_set
 
     def background_model_samples(self):
-        samples = [sset.samples for sset in self._create_random_sample_set()] 
+        samples = [sset.samples for sset in self._create_random_sample_set()]
         return list(itertools.chain(*samples))
 
     def references(self):
@@ -101,11 +116,12 @@ def _make_transformer(dir_name):
             dir_name,
             transform_extra_arguments=(("annotations", "annotations"),),
         ),
-        wrap_bob_legacy(FakeExtractor(), dir_name,)
+        wrap_bob_legacy(FakeExtractor(), dir_name,),
     )
 
     return pipeline
 
+
 def _make_transformer_with_algorithm(dir_name):
     pipeline = make_pipeline(
         wrap_bob_legacy(
@@ -114,7 +130,7 @@ def _make_transformer_with_algorithm(dir_name):
             transform_extra_arguments=(("annotations", "annotations"),),
         ),
         wrap_bob_legacy(FakeExtractor(), dir_name),
-        wrap_bob_legacy(FakeAlgorithm(), dir_name)
+        wrap_bob_legacy(FakeAlgorithm(), dir_name),
     )
 
     return pipeline
@@ -197,7 +213,9 @@ def test_checkpoint_bioalg_as_transformer():
             if isinstance(score_writer, CSVScoreWriter):
                 base_path = os.path.join(dir_name, "concatenated_scores")
                 score_writer.concatenate_write_scores(scores, base_path)
-                assert len(open(os.path.join(base_path, "chunk_0.csv")).readlines()) == 101
+                assert (
+                    len(open(os.path.join(base_path, "chunk_0.csv")).readlines()) == 101
+                )
             else:
                 filename = os.path.join(dir_name, "concatenated_scores.txt")
                 score_writer.concatenate_write_scores(scores, filename)
@@ -205,24 +223,24 @@ def test_checkpoint_bioalg_as_transformer():
 
         run_pipeline(False)
         run_pipeline(False)  # Checking if the checkpointing works
-        shutil.rmtree(dir_name) # Deleting the cache so it runs again from scratch
+        shutil.rmtree(dir_name)  # Deleting the cache so it runs again from scratch
         os.makedirs(dir_name, exist_ok=True)
 
         # Dask
         run_pipeline(True)
         run_pipeline(True)  # Checking if the checkpointing works
-        shutil.rmtree(dir_name) # Deleting the cache so it runs again from scratch
+        shutil.rmtree(dir_name)  # Deleting the cache so it runs again from scratch
         os.makedirs(dir_name, exist_ok=True)
 
         # CSVWriter
         run_pipeline(False, CSVScoreWriter())
-        run_pipeline(False, CSVScoreWriter()) # Checking if the checkpointng works
-        shutil.rmtree(dir_name) # Deleting the cache so it runs again from scratch
+        run_pipeline(False, CSVScoreWriter())  # Checking if the checkpointing works
+        shutil.rmtree(dir_name)  # Deleting the cache so it runs again from scratch
        os.makedirs(dir_name, exist_ok=True)
 
         # CSVWriter + Dask
         run_pipeline(True, CSVScoreWriter())
-        run_pipeline(True, CSVScoreWriter()) # Checking if the checkpointng works
+        run_pipeline(True, CSVScoreWriter())  # Checking if the checkpointing works
 
 
 def test_checkpoint_bioalg_as_bioalg():
@@ -231,12 +249,15 @@ def test_checkpoint_bioalg_as_bioalg():
 
         def run_pipeline(with_dask, score_writer=FourColumnsScoreWriter()):
             database = DummyDatabase()
-            
+
             transformer = _make_transformer_with_algorithm(dir_name)
             projector_file = transformer[2].estimator.estimator.projector_file
 
             biometric_algorithm = BioAlgorithmLegacy(
-                FakeAlgorithm(), base_dir=dir_name, score_writer=score_writer, projector_file=projector_file
+                FakeAlgorithm(),
+                base_dir=dir_name,
+                score_writer=score_writer,
+                projector_file=projector_file,
             )
 
             vanilla_biometrics_pipeline = VanillaBiometricsPipeline(
@@ -265,11 +286,11 @@ def test_checkpoint_bioalg_as_bioalg():
 
         run_pipeline(False)
         run_pipeline(False)  # Checking if the checkpointing works
-        shutil.rmtree(dir_name) # Deleting the cache so it runs again from scratch
+        shutil.rmtree(dir_name)  # Deleting the cache so it runs again from scratch
         os.makedirs(dir_name, exist_ok=True)
 
         # Dask
         run_pipeline(True)
         run_pipeline(True)  # Checking if the checkpointing works
-        shutil.rmtree(dir_name) # Deleting the cache so it runs again from scratch
+        shutil.rmtree(dir_name)  # Deleting the cache so it runs again from scratch
         os.makedirs(dir_name, exist_ok=True)
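Note: for reviewers who want to try the pattern these tests exercise, a rough sketch (not part of the patch; `_make_transformer_with_algorithm` and `FakeAlgorithm` are the test helpers above, and the pipeline construction mirrors the test body):

```python
import tempfile

from bob.bio.base.pipelines.vanilla_biometrics import (
    BioAlgorithmLegacy,
    FourColumnsScoreWriter,
    VanillaBiometricsPipeline,
)
from bob.bio.base.test.test_transformers import FakeAlgorithm
from bob.bio.base.test.test_vanilla_biometrics import _make_transformer_with_algorithm

with tempfile.TemporaryDirectory() as dir_name:
    # Legacy preprocessor -> extractor -> algorithm chain, checkpointed in dir_name.
    transformer = _make_transformer_with_algorithm(dir_name)

    # Reuse the same legacy Algorithm for enrollment/scoring, pointing it at the
    # projector file trained by the wrapped transformer above.
    biometric_algorithm = BioAlgorithmLegacy(
        FakeAlgorithm(),
        base_dir=dir_name,
        score_writer=FourColumnsScoreWriter(),
        projector_file=transformer[2].estimator.estimator.projector_file,
    )

    pipeline = VanillaBiometricsPipeline(transformer, biometric_algorithm)
```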
diff --git a/bob/bio/base/wrappers.py b/bob/bio/base/wrappers.py
index dc2ef7d82176b28fd98c60f02a44e80adb4a3d6a..b37cb730d56674d045abdaa5a7402b89b6b7c108 100644
--- a/bob/bio/base/wrappers.py
+++ b/bob/bio/base/wrappers.py
@@ -11,6 +11,7 @@ from bob.bio.base.extractor import Extractor
 from bob.bio.base.algorithm import Algorithm
 import bob.pipelines as mario
 import os
+from bob.bio.base.utils import is_argument_available
 
 
 def wrap_bob_legacy(
@@ -18,7 +19,7 @@ def wrap_bob_legacy(
     dir_name,
     fit_extra_arguments=(("y", "subject"),),
     transform_extra_arguments=None,
-    dask_it=False
+    dask_it=False,
 ):
     """
     Wraps either :any:`bob.bio.base.preprocessor.Preprocessor`, :any:`bob.bio.base.extractor.Extractor`
@@ -47,33 +48,20 @@ def wrap_bob_legacy(
     """
 
     if isinstance(bob_object, Preprocessor):
-        preprocessor_transformer = PreprocessorTransformer(bob_object)
-        transformer = wrap_preprocessor(
-            preprocessor_transformer,
-            features_dir=os.path.join(dir_name, "preprocessor"),
-            transform_extra_arguments=transform_extra_arguments,
+        transformer = wrap_checkpoint_preprocessor(
+            bob_object, features_dir=os.path.join(dir_name, "preprocessor"),
         )
     elif isinstance(bob_object, Extractor):
-        extractor_transformer = ExtractorTransformer(bob_object)
-        path = os.path.join(dir_name, "extractor")
-        transformer = wrap_extractor(
-            extractor_transformer,
-            features_dir=path,
-            model_path=os.path.join(path, "extractor.pkl"),
-            transform_extra_arguments=transform_extra_arguments,
-            fit_extra_arguments=fit_extra_arguments,
+        transformer = wrap_checkpoint_extractor(
+            bob_object,
+            features_dir=os.path.join(dir_name, "extractor"),
+            model_path=dir_name,
         )
     elif isinstance(bob_object, Algorithm):
-        path = os.path.join(dir_name, "algorithm")
-        algorithm_transformer = AlgorithmTransformer(
-            bob_object, projector_file=os.path.join(path, "Projector.hdf5")
-        )
-        transformer = wrap_algorithm(
-            algorithm_transformer,
-            features_dir=path,
-            model_path=os.path.join(path, "algorithm.pkl"),
-            transform_extra_arguments=transform_extra_arguments,
-            fit_extra_arguments=fit_extra_arguments,
+        transformer = wrap_checkpoint_algorithm(
+            bob_object,
+            features_dir=os.path.join(dir_name, "algorithm"),
+            model_path=dir_name,
         )
     else:
         raise ValueError(
@@ -86,132 +74,351 @@ def wrap_bob_legacy(
     return transformer
 
 
-def wrap_preprocessor(
-    preprocessor_transformer, features_dir=None, transform_extra_arguments=None,
+def wrap_sample_preprocessor(
+    preprocessor,
+    transform_extra_arguments=(("annotations", "annotations"),),
+    **kwargs
+):
+    """
+    Wraps :any:`bob.bio.base.preprocessor.Preprocessor` with
+    :any:`bob.pipelines.wrappers.SampleWrapper`
+
+    .. warning::
+       This wrapper does not checkpoint data
+
+    Parameters
+    ----------
+
+    preprocessor: :any:`bob.bio.base.preprocessor.Preprocessor`
+       Instance of :any:`bob.bio.base.preprocessor.Preprocessor` to be wrapped
+
+    transform_extra_arguments: [tuple]
+        Same behavior as in :any:`bob.pipelines.wrappers.transform_extra_arguments`
+
+    """
+
+    transformer = PreprocessorTransformer(preprocessor)
+    return mario.wrap(
+        ["sample"],
+        transformer,
+        transform_extra_arguments=transform_extra_arguments,
+    )
+
+
+def wrap_checkpoint_preprocessor(
+    preprocessor,
+    features_dir=None,
+    transform_extra_arguments=(("annotations", "annotations"),),
+    load_func=None,
+    save_func=None,
+    extension=".hdf5",
 ):
     """
-    Wraps :any:`bob.bio.base.transformers.PreprocessorTransformer` with
+    Wraps :any:`bob.bio.base.preprocessor.Preprocessor` with
     :any:`bob.pipelines.wrappers.CheckpointWrapper` and :any:`bob.pipelines.wrappers.SampleWrapper`
 
     Parameters
     ----------
 
-    preprocessor_transformer: :any:`bob.bio.base.transformers.PreprocessorTransformer`
+    preprocessor: :any:`bob.bio.base.preprocessor.Preprocessor`
       Instance of :any:`bob.bio.base.preprocessor.Preprocessor` to be wrapped
 
     features_dir: str
-       Features directory to be checkpointed
+       Features directory to be checkpointed (see :any:`bob.pipelines.CheckpointWrapper`).
+
+    extension : str, optional
+        Extension of preprocessed files (see :any:`bob.pipelines.CheckpointWrapper`).
+
+    load_func : None, optional
+        Function that loads data to be preprocessed.
+        The default is :any:`bob.bio.base.preprocessor.Preprocessor.read_data`
+
+    save_func : None, optional
+        Function that saves preprocessed data.
+        The default is :any:`bob.bio.base.preprocessor.Preprocessor.write_data`
 
     transform_extra_arguments: [tuple]
         Same behavior as in :any:`bob.pipelines.wrappers.transform_extra_arguments`
 
     """
-    if not isinstance(preprocessor_transformer, PreprocessorTransformer):
-        raise ValueError(
-            f"Expected an instance of PreprocessorTransformer, not {preprocessor_transformer}"
-        )
-
+    transformer = PreprocessorTransformer(preprocessor)
     return mario.wrap(
         ["sample", "checkpoint"],
-        preprocessor_transformer,
-        load_func=preprocessor_transformer.callable.read_data,
-        save_func=preprocessor_transformer.callable.write_data,
+        transformer,
+        load_func=load_func or preprocessor.read_data,
+        save_func=save_func or preprocessor.write_data,
         features_dir=features_dir,
         transform_extra_arguments=transform_extra_arguments,
+        extension=extension,
     )
 
 
-def wrap_extractor(
-    extractor_transformer,
+def _prepare_extractor_sample_args(
+    extractor, transform_extra_arguments, fit_extra_arguments
+):
+    if transform_extra_arguments is None and is_argument_available(
+        "metadata", extractor.__call__
+    ):
+        transform_extra_arguments = (("metadata", "metadata"),)
+
+    if (
+        fit_extra_arguments is None
+        and extractor.requires_training
+        and extractor.split_training_data_by_client
+    ):
+        fit_extra_arguments = (("y", "subject"),)
+
+    return transform_extra_arguments, fit_extra_arguments
+
+
+def wrap_sample_extractor(
+    extractor,
     fit_extra_arguments=None,
     transform_extra_arguments=None,
+    model_path=None,
+    **kwargs,
+):
+    """
+    Wraps :any:`bob.bio.base.extractor.Extractor` with
+    :any:`bob.pipelines.wrappers.SampleWrapper`
+
+    Parameters
+    ----------
+
+    extractor: :any:`bob.bio.base.extractor.Extractor`
+       Instance of :any:`bob.bio.base.extractor.Extractor` to be wrapped
+
+    transform_extra_arguments: [tuple], optional
+        Same behavior as in :any:`bob.pipelines.wrappers.transform_extra_arguments`
+
+    model_path: str
+        Path to `extractor_file` in :any:`bob.bio.base.extractor.Extractor`
+
+    """
+
+    extractor_file = (
+        os.path.join(model_path, "Extractor.hdf5") if model_path is not None else None
+    )
+
+    transformer = ExtractorTransformer(extractor, model_path=extractor_file)
+
+    transform_extra_arguments, fit_extra_arguments = _prepare_extractor_sample_args(
+        extractor, transform_extra_arguments, fit_extra_arguments
+    )
+
+    return mario.wrap(
+        ["sample"],
+        transformer,
+        transform_extra_arguments=transform_extra_arguments,
+        fit_extra_arguments=fit_extra_arguments,
+        **kwargs,
+    )
+
+
+def wrap_checkpoint_extractor(
+    extractor,
     features_dir=None,
+    fit_extra_arguments=None,
+    transform_extra_arguments=None,
+    load_func=None,
+    save_func=None,
+    extension=".hdf5",
     model_path=None,
+    **kwargs,
 ):
     """
-    Wraps :any:`bob.bio.base.transformers.ExtractorTransformer` with
+    Wraps :any:`bob.bio.base.extractor.Extractor` with
     :any:`bob.pipelines.wrappers.CheckpointWrapper` and :any:`bob.pipelines.wrappers.SampleWrapper`
 
     Parameters
     ----------
 
-    extractor_transformer: :any:`bob.bio.base.transformers.ExtractorTransformer`
+    extractor: :any:`bob.bio.base.extractor.Extractor`
       Instance of :any:`bob.bio.base.extractor.Extractor` to be wrapped
 
     features_dir: str
-        Features directory to be checkpointed
+        Features directory to be checkpointed (see :any:`bob.pipelines.CheckpointWrapper`).
 
-    model_path: str
-        Path to checkpoint the model
+    extension : str, optional
+        Extension of extracted files (see :any:`bob.pipelines.CheckpointWrapper`).
+
+    load_func : None, optional
+        Function that loads the extracted features.
+        The default is :any:`bob.bio.base.extractor.Extractor.read_feature`
+
+    save_func : None, optional
+        Function that saves the extracted features.
+        The default is :any:`bob.bio.base.extractor.Extractor.write_feature`
 
     fit_extra_arguments: [tuple]
         Same behavior as in :any:`bob.pipelines.wrappers.fit_extra_arguments`
 
-    transform_extra_arguments: [tuple]
+    transform_extra_arguments: [tuple], optional
         Same behavior as in :any:`bob.pipelines.wrappers.transform_extra_arguments`
 
+    model_path: str
+        See :any:`bob.bio.base.transformers.ExtractorTransformer`.
+
     """
 
-    if not isinstance(extractor_transformer, ExtractorTransformer):
-        raise ValueError(
-            f"Expected an instance of ExtractorTransformer, not {extractor_transformer}"
-        )
+    extractor_file = (
+        os.path.join(model_path, "Extractor.hdf5") if model_path is not None else None
+    )
+
+    model_file = (
+        os.path.join(model_path, "Extractor.pkl") if model_path is not None else None
+    )
+    transformer = ExtractorTransformer(extractor, model_path=extractor_file)
+
+    transform_extra_arguments, fit_extra_arguments = _prepare_extractor_sample_args(
+        extractor, transform_extra_arguments, fit_extra_arguments
+    )
 
     return mario.wrap(
         ["sample", "checkpoint"],
-        extractor_transformer,
-        load_func=extractor_transformer.callable.read_feature,
-        save_func=extractor_transformer.callable.write_feature,
-        model_path=model_path,
+        transformer,
+        load_func=load_func or extractor.read_feature,
+        save_func=save_func or extractor.write_feature,
+        model_path=model_file,
         features_dir=features_dir,
         transform_extra_arguments=transform_extra_arguments,
         fit_extra_arguments=fit_extra_arguments,
+        **kwargs,
     )
 
 
-def wrap_algorithm(
-    algorithm_transformer,
+def _prepare_algorithm_sample_args(
+    algorithm, transform_extra_arguments, fit_extra_arguments
+):
+
+    if transform_extra_arguments is None and is_argument_available(
+        "metadata", algorithm.project
+    ):
+        transform_extra_arguments = (("metadata", "metadata"),)
+
+    if (
+        fit_extra_arguments is None
+        and algorithm.requires_projector_training
+        and algorithm.split_training_features_by_client
+    ):
+        fit_extra_arguments = (("y", "subject"),)
+
+    return transform_extra_arguments, fit_extra_arguments
+
+
+def wrap_sample_algorithm(
+    algorithm,
+    model_path=None,
     fit_extra_arguments=None,
     transform_extra_arguments=None,
-    features_dir=None,
+    **kwargs,
+):
+    """
+    Wraps :any:`bob.bio.base.algorithm.Algorithm` with
+    :any:`bob.pipelines.wrappers.SampleWrapper`
+
+    Parameters
+    ----------
+
+    algorithm: :any:`bob.bio.base.algorithm.Algorithm`
+       Instance of :any:`bob.bio.base.algorithm.Algorithm` to be wrapped
+
+    model_path: str
+        Path to `projector_file` in :any:`bob.bio.base.algorithm.Algorithm`
+
+    fit_extra_arguments: [tuple]
+        Same behavior as in :any:`bob.pipelines.wrappers.fit_extra_arguments`
+
+    transform_extra_arguments: [tuple]
+        Same behavior as in :any:`bob.pipelines.wrappers.transform_extra_arguments`
+
+    """
+
+    projector_file = (
+        os.path.join(model_path, "Projector.hdf5") if model_path is not None else None
+    )
+
+    transformer = AlgorithmTransformer(algorithm, projector_file=projector_file)
+
+    transform_extra_arguments, fit_extra_arguments = _prepare_algorithm_sample_args(
+        algorithm, transform_extra_arguments, fit_extra_arguments
+    )
+
+    return mario.wrap(
+        ["sample"],
+        transformer,
+        transform_extra_arguments=transform_extra_arguments,
+        fit_extra_arguments=fit_extra_arguments,
+    )
+
+
+def wrap_checkpoint_algorithm(
+    algorithm,
     model_path=None,
+    features_dir=None,
+    extension=".hdf5",
+    fit_extra_arguments=None,
+    transform_extra_arguments=None,
+    load_func=None,
+    save_func=None,
+    **kwargs,
 ):
     """
-    Wraps :any:`bob.bio.base.transformers.AlgorithmTransformer` with
+    Wraps :any:`bob.bio.base.algorithm.Algorithm` with
     :any:`bob.pipelines.wrappers.CheckpointWrapper` and :any:`bob.pipelines.wrappers.SampleWrapper`
 
     Parameters
     ----------
 
-    algorithm_transformer: :any:`bob.bio.base.transformers.AlgorithmTransformer`
+    algorithm: :any:`bob.bio.base.algorithm.Algorithm`
       Instance of :any:`bob.bio.base.algorithm.Algorithm` to be wrapped
 
     features_dir: str
-        Features directory to be checkpointed
+        Features directory to be checkpointed (see :any:`bob.pipelines.CheckpointWrapper`).
 
     model_path: str
         Path to checkpoint the model
 
+    extension : str, optional
+        Extension of projected files (see :any:`bob.pipelines.CheckpointWrapper`).
+
     fit_extra_arguments: [tuple]
         Same behavior as in :any:`bob.pipelines.wrappers.fit_extra_arguments`
 
     transform_extra_arguments: [tuple]
         Same behavior as in :any:`bob.pipelines.wrappers.transform_extra_arguments`
 
+    load_func : None, optional
+        Function that loads the projected features.
+        The default is :any:`bob.bio.base.algorithm.Algorithm.read_feature`
+
+    save_func : None, optional
+        Function that saves the projected features.
+        The default is :any:`bob.bio.base.algorithm.Algorithm.write_feature`
+
+
     """
-    if not isinstance(algorithm_transformer, AlgorithmTransformer):
-        raise ValueError(
-            f"Expected an instance of AlgorithmTransformer, not {algorithm_transformer}"
-        )
+    projector_file = (
+        os.path.join(model_path, "Projector.hdf5") if model_path is not None else None
+    )
+
+    model_file = (
+        os.path.join(model_path, "Projector.pkl") if model_path is not None else None
+    )
+    transformer = AlgorithmTransformer(algorithm, projector_file=projector_file)
+
+    transform_extra_arguments, fit_extra_arguments = _prepare_algorithm_sample_args(
+        algorithm, transform_extra_arguments, fit_extra_arguments
+    )
 
     return mario.wrap(
         ["sample", "checkpoint"],
-        algorithm_transformer,
-        load_func=algorithm_transformer.callable.read_feature,
-        save_func=algorithm_transformer.callable.write_feature,
-        model_path=model_path,
+        transformer,
+        load_func=load_func or algorithm.read_feature,
+        save_func=save_func or algorithm.write_feature,
+        model_path=model_file,
         features_dir=features_dir,
         transform_extra_arguments=transform_extra_arguments,
         fit_extra_arguments=fit_extra_arguments,
+        **kwargs,
     )
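Note: end to end, the new wrappers compose as in the sketch below, which mirrors `test_wrap_bob_pipeline` above (`generate_samples` and the fakes are helpers from `test_transformers.py`; the dask wrap is optional):

```python
import tempfile

from sklearn.pipeline import make_pipeline

import bob.pipelines as mario
from bob.bio.base.test.test_transformers import (
    FakeAlgorithm,
    FakeExtractor,
    FakePreprocesor,
    generate_samples,
)
from bob.bio.base.wrappers import (
    wrap_checkpoint_algorithm,
    wrap_checkpoint_extractor,
    wrap_checkpoint_preprocessor,
)

with tempfile.TemporaryDirectory() as dir_name:
    # Every stage is sample-wrapped and checkpoints its output under dir_name.
    pipeline = make_pipeline(
        wrap_checkpoint_preprocessor(FakePreprocesor(), dir_name),
        wrap_checkpoint_extractor(FakeExtractor(), dir_name),
        wrap_checkpoint_algorithm(FakeAlgorithm(), dir_name),
    )

    # Optional: run the whole pipeline lazily through dask.
    pipeline = mario.wrap(["dask"], pipeline)

    training = generate_samples(n_subjects=2, n_samples_per_subject=2)
    probes = generate_samples(n_subjects=1, n_samples_per_subject=1)
    transformed = (
        pipeline.fit(training)
        .transform(probes)
        .compute(scheduler="single-threaded")
    )
```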