diff --git a/bob/bio/base/algorithm/gmm.py b/bob/bio/base/algorithm/gmm.py index 7295dceb9b2c55df3d9c6287cf22d0750d5a8a23..765eaf2a2e544434f9093b457b63985a67cc4312 100644 --- a/bob/bio/base/algorithm/gmm.py +++ b/bob/bio/base/algorithm/gmm.py @@ -117,7 +117,7 @@ class GMM(GMMMachine, BioAlgorithm): update_means Decides wether the means of the Gaussians are updated while training. update_variances - Decides wether the variancess of the Gaussians are updated while training. + Decides wether the variances of the Gaussians are updated while training. enroll_iterations Number of iterations for the MAP GMM used for enrollment. enroll_update_weights @@ -125,7 +125,7 @@ class GMM(GMMMachine, BioAlgorithm): enroll_update_means Decides wether the means of the Gaussians are updated while enrolling. enroll_update_variances - Decides wether the variancess of the Gaussians are updated while enrolling. + Decides wether the variances of the Gaussians are updated while enrolling. enroll_relevance_factor For enrollment: MAP relevance factor as described in Reynolds paper. If None, will not apply Reynolds adaptation. @@ -246,7 +246,7 @@ class GMM(GMMMachine, BioAlgorithm): X = check_data_dim(X, expected_ndim=2) logger.debug( - f"Creating UBM machine with {self.n_gaussians} gaussians and {len(X)} samples" + f"Training UBM machine with {self.n_gaussians} gaussians and {len(X)} samples" ) super().fit(X) diff --git a/bob/bio/base/database/filelist/query.py b/bob/bio/base/database/filelist/query.py index 9cb376afbd4c2ca0d978122dcffe253c8bf802c2..dcb39dd3c396957c7e5a6b4f72b9a8cd41418387 100644 --- a/bob/bio/base/database/filelist/query.py +++ b/bob/bio/base/database/filelist/query.py @@ -3,6 +3,7 @@ import logging import os +from bob.bio.base.database.legacy import check_parameters_for_validity from bob.bio.base.utils.annotations import read_annotation_file from .. import BioFile, ZTBioDatabase @@ -267,7 +268,7 @@ class FileListBioDatabase(ZTBioDatabase): group, self.protocol, **self.z_probe_options ) else: - logger.warn( + logger.warning( "ZT score files are requested, but no such files are defined in group %s for protocol %s", group, self.protocol, @@ -390,7 +391,7 @@ class FileListBioDatabase(ZTBioDatabase): ``True`` if the all file lists for ZT score normalization exist, otherwise ``False``. """ protocol = protocol or self.protocol - groups = self.check_parameters_for_validity( + groups = check_parameters_for_validity( groups, "group", self.groups(protocol, add_world=False) ) @@ -492,7 +493,7 @@ class FileListBioDatabase(ZTBioDatabase): The client id for the given model id, if found. """ protocol = self.protocol - groups = self.check_parameters_for_validity( + groups = check_parameters_for_validity( group, "group", self.groups(protocol), @@ -534,7 +535,7 @@ class FileListBioDatabase(ZTBioDatabase): The client id for the given model id of a T-Norm model, if found. 
""" protocol = self.protocol - groups = self.check_parameters_for_validity( + groups = check_parameters_for_validity( group, "group", self.groups(protocol, add_world=False) ) @@ -584,7 +585,7 @@ class FileListBioDatabase(ZTBioDatabase): """ protocol = protocol or self.protocol - groups = self.check_parameters_for_validity( + groups = check_parameters_for_validity( groups, "group", self.groups(protocol), @@ -613,7 +614,7 @@ class FileListBioDatabase(ZTBioDatabase): """ protocol = protocol or self.protocol - groups = self.check_parameters_for_validity( + groups = check_parameters_for_validity( groups, "group", self.groups(protocol, add_world=False) ) @@ -639,7 +640,7 @@ class FileListBioDatabase(ZTBioDatabase): """ protocol = protocol or self.protocol - groups = self.check_parameters_for_validity( + groups = check_parameters_for_validity( groups, "group", self.groups(protocol, add_world=False) ) @@ -675,7 +676,7 @@ class FileListBioDatabase(ZTBioDatabase): A list containing all the model ids which have the given properties. """ protocol = protocol or self.protocol - groups = self.check_parameters_for_validity( + groups = check_parameters_for_validity( groups, "group", self.groups(protocol=protocol) ) @@ -700,7 +701,7 @@ class FileListBioDatabase(ZTBioDatabase): A list containing all the T-Norm model ids belonging to the given group. """ protocol = protocol or self.protocol - groups = self.check_parameters_for_validity( + groups = check_parameters_for_validity( groups, "group", self.groups(protocol, add_world=False) ) @@ -759,16 +760,16 @@ class FileListBioDatabase(ZTBioDatabase): "To be able to use the 'classes' keyword, please use the 'for_scores.lst' list file." ) - purposes = self.check_parameters_for_validity( + purposes = check_parameters_for_validity( purposes, "purpose", ("enroll", "probe") ) - groups = self.check_parameters_for_validity( + groups = check_parameters_for_validity( groups, "group", self.groups(protocol), default_parameters=self.groups(protocol, add_subworld=False), ) - classes = self.check_parameters_for_validity( + classes = check_parameters_for_validity( classes, "class", ("client", "impostor") ) @@ -902,7 +903,7 @@ class FileListBioDatabase(ZTBioDatabase): A list of :py:class:`BioFile` objects considering all the filtering criteria. 
""" protocol = protocol or self.protocol - groups = self.check_parameters_for_validity( + groups = check_parameters_for_validity( groups, "group", self.groups(protocol, add_world=False) ) @@ -943,7 +944,7 @@ class FileListBioDatabase(ZTBioDatabase): """ protocol = protocol or self.protocol - groups = self.check_parameters_for_validity( + groups = check_parameters_for_validity( groups, "group", self.groups(protocol, add_world=False) ) diff --git a/bob/bio/base/pipelines/__init__.py b/bob/bio/base/pipelines/__init__.py index 1f5e885f6206b1e0ccc2f879b905f31ab095fa8a..415523486da7e8341a1957e4bf2f6177297a7a31 100644 --- a/bob/bio/base/pipelines/__init__.py +++ b/bob/bio/base/pipelines/__init__.py @@ -26,6 +26,7 @@ from .score_post_processor import ( # noqa: F401 from .entry_points import ( # noqa: F401 execute_pipeline_simple, execute_pipeline_score_norm, + execute_pipeline_train, ) diff --git a/bob/bio/base/pipelines/entry_points.py b/bob/bio/base/pipelines/entry_points.py index 237d9c31b8ca4d55f8764a78202cecf1169b25bb..03aa26a5cd5531c02a3f5eeb41a9f861f08bf740 100644 --- a/bob/bio/base/pipelines/entry_points.py +++ b/bob/bio/base/pipelines/entry_points.py @@ -1,22 +1,35 @@ +import glob import logging import os +import pickle +import random + +from typing import Optional, Union import dask.bag from dask.delayed import Delayed +from sklearn.pipeline import Pipeline from bob.bio.base.pipelines import ( BioAlgDaskWrapper, CSVScoreWriter, + Database, FourColumnsScoreWriter, PipelineScoreNorm, + PipelineSimple, TNormScores, ZNormScores, checkpoint_pipeline_simple, dask_bio_pipeline, is_biopipeline_checkpointed, ) -from bob.pipelines import estimator_requires_fit, is_instance_nested, wrap +from bob.pipelines import ( + DaskWrapper, + estimator_requires_fit, + is_instance_nested, + wrap, +) from bob.pipelines.distributed import dask_get_partition_size from bob.pipelines.distributed.sge import SGEMultipleQueuesCluster @@ -137,7 +150,7 @@ def execute_pipeline_simple( os.path.join(output, "./tmp") ) - # Checkpoint if it's already checkpointed + # Checkpoint if it's not already checkpointed if checkpoint and not is_biopipeline_checkpointed(pipeline): hash_fn = database.hash_fn if hasattr(database, "hash_fn") else None pipeline = checkpoint_pipeline_simple( @@ -190,11 +203,11 @@ def execute_pipeline_simple( ) elif dask_n_partitions is not None or dask_n_workers is not None: # Divide each Set in a user-defined number of partitions - logger.debug("Splitting data with fixed number of partitions.") - pipeline = dask_bio_pipeline( - pipeline, - npartitions=dask_n_partitions or dask_n_workers, + n_partitions = dask_n_partitions or dask_n_workers + logger.debug( + f"Splitting data with fixed number of partitions: {n_partitions}." 
) + pipeline = dask_bio_pipeline(pipeline, npartitions=n_partitions) else: # Split in max_jobs partitions or revert to the default behavior of # dask.Bag from_sequence: partition_size = 100 @@ -466,3 +479,182 @@ def execute_pipeline_score_norm( ) _ = compute_scores(zt_normed_scores, dask_client) """ + + +def execute_pipeline_train( + pipeline: Union[PipelineSimple, Pipeline], + database: Database, + dask_client: Optional[dask.distributed.Client] = None, + output: str = "./results", + checkpoint: bool = True, + dask_n_partitions: Optional[int] = None, + dask_partition_size: Optional[int] = None, + dask_n_workers: Optional[int] = None, + checkpoint_dir: Optional[str] = None, + force: bool = False, + split_training: bool = False, + n_splits: int = 3, + **kwargs, +): + """Executes only the training part of a pipeline. + + When running from a script, use this function instead of the click command in + ``bob.bio.base.script.pipeline_train``. + + Parameters + ---------- + + pipeline: + A constructed ``PipelineSimple`` object (the ``transformer`` will be extracted + for training), or an ``sklearn.Pipeline``. + + database: + A database interface instance + + dask_client: + A Dask client instance used to run the experiment in parallel on multiple + machines, or locally. + Basic configs can be found in ``bob.pipelines.config.distributed``. + + dask_n_partitions: + Specifies a number of partitions to split the data into. + + dask_partition_size: + Specifies a data partition size when using dask. Ignored when dask_n_partitions + is set. + + dask_n_workers: + Sets the starting number of Dask workers. Does not prevent Dask from requesting + more or releasing workers depending on load. + + output: + Path where the scores will be saved. + + checkpoint: + Whether checkpoint files will be created for every step of the pipelines. + + checkpoint_dir: + If `checkpoint` is set, this path will be used to save the checkpoints. + If `None`, the content of `output` will be used. + + force: + If set, it will force generate all the checkpoints of an experiment. This option doesn't work if `--memory` is set + + split_training: + If set, the background model will be trained on multiple partitions of the data. + + n_splits: + Number of splits to use when splitting the data. + """ + + logger.debug(f"Unused arguments: {kwargs=}") + if not os.path.exists(output): + os.makedirs(output, exist_ok=True) + + # Setting the `checkpoint_dir` + if checkpoint_dir is None: + checkpoint_dir = output + else: + os.makedirs(checkpoint_dir, exist_ok=True) + + if isinstance(pipeline, PipelineSimple): + pipeline = pipeline.transformer + + # Checkpoint (only features, not the model) + if checkpoint: + hash_fn = database.hash_fn if hasattr(database, "hash_fn") else None + wrap( + ["checkpoint"], + pipeline, + features_dir=checkpoint_dir, + model_path=None, + hash_fn=hash_fn, + force=force, + ) + + if not estimator_requires_fit(pipeline): + raise ValueError( + "Estimator does not require fitting. No training necessary." + ) + + background_model_samples = database.background_model_samples() + + if dask_client is not None: + # Scaling up + if dask_n_workers is not None and not isinstance(dask_client, str): + dask_client.cluster.scale(dask_n_workers) + + if dask_partition_size is not None: + logger.debug( + f"Splitting data with fixed size partitions: {dask_partition_size}." 
+ ) + pipeline = wrap( + ["dask"], pipeline, partition_size=dask_partition_size + ) + elif dask_n_partitions is not None or dask_n_workers is not None: + # Divide each Set in a user-defined number of partitions + n_partitions = dask_n_partitions or dask_n_workers + logger.debug( + f"Splitting data with fixed number of partitions: {n_partitions}." + ) + pipeline = wrap(["dask"], pipeline, npartitions=n_partitions) + else: + # Split in max_jobs partitions or revert to the default behavior of + # dask.Bag from_sequence: partition_size = 100 + n_jobs = None + if not isinstance(dask_client, str) and isinstance( + dask_client.cluster, SGEMultipleQueuesCluster + ): + logger.debug( + "Splitting data according to the number of available workers." + ) + n_jobs = dask_client.cluster.sge_job_spec["default"]["max_jobs"] + logger.debug(f"{n_jobs} partitions will be created.") + pipeline = wrap(["dask"], pipeline, npartitions=n_jobs) + + logger.info("Running the pipeline training") + if split_training: + start_step = -1 + # Look at step files, and assess if we can load the last one + for step_file in glob.glob( + os.path.join(output, "train_pipeline_step_*.pkl") + ): + to_rem = os.path.join(output, "train_pipeline_step_") + file_step = int(step_file.replace(to_rem, "").replace(".pkl", "")) + start_step = max(start_step, file_step) + if start_step > -1: + logger.debug("Found pipeline training step. Loading it.") + last_step_file = os.path.join( + output, f"train_pipeline_step_{start_step}.pkl" + ) + with open(last_step_file, "rb") as start_file: + pipeline = pickle.load(start_file) + start_step += 1 # Loaded step is i -> training starts a i+1 + logger.info(f"Starting from training step {start_step}") + + random.seed(0) + random.shuffle(background_model_samples) + + for partition_i in range(start_step, n_splits): + logger.info( + f"Training with partition {partition_i} ({partition_i+1}/{n_splits})" + ) + start = len(background_model_samples) // n_splits * partition_i + end = len(background_model_samples) // n_splits * (partition_i + 1) + _ = pipeline.fit(background_model_samples[start:end]) + step_path = os.path.join( + output, f"train_pipeline_step_{partition_i}.pkl" + ) + with open(step_path, "wb") as f: + pickle.dump(pipeline, f) + else: + _ = pipeline.fit(background_model_samples) + + # Save each fitted transformer + for transformer_name, transformer in pipeline.steps: + if transformer._get_tags()["requires_fit"]: + if isinstance(transformer, DaskWrapper): + transformer = transformer.estimator + step_path = os.path.join(output, f"{transformer_name}.pkl") + with open(step_path, "wb") as f: + pickle.dump(transformer, f) diff --git a/bob/bio/base/pipelines/wrappers.py b/bob/bio/base/pipelines/wrappers.py index 2bc3414b55a1717c8cac87b81e5c94f4fe121bfd..260740a37881c71f53d9a3c28a5d7e3777a380c0 100644 --- a/bob/bio/base/pipelines/wrappers.py +++ b/bob/bio/base/pipelines/wrappers.py @@ -344,7 +344,7 @@ def is_biopipeline_checkpointed(pipeline): """ - # We have to check if biomtric_algorithm is checkpointed + # We have to check if biometric_algorithm is checkpointed return is_instance_nested( pipeline, "biometric_algorithm", BioAlgCheckpointWrapper ) diff --git a/bob/bio/base/script/gen.py b/bob/bio/base/script/gen.py index 1105c5fcf698d09729071579205cb6605bf94141..ce84850286c6faa368311c3ab0f3bdfa3ccf71cc 100644 --- a/bob/bio/base/script/gen.py +++ b/bob/bio/base/script/gen.py @@ -8,7 +8,6 @@ import click import numpy from bob.extension.scripts.click_helper import verbosity_option -from bob.io.base import 
create_directories_safe logger = logging.getLogger(__name__) @@ -98,7 +97,7 @@ def write_scores_to_file( If 5-colum format, else 4-column """ logger.debug(f"Creating result directories ('{filename}').") - create_directories_safe(os.path.dirname(filename)) + os.makedirs(os.path.dirname(filename), exist_ok=True) s_subjects = ["x%d" % i for i in range(n_subjects)] logger.debug("Writing scores to files.") diff --git a/bob/bio/base/script/pipeline_train.py b/bob/bio/base/script/pipeline_train.py new file mode 100644 index 0000000000000000000000000000000000000000..dbbd3b9f7b4ab55f572c439ffa68e6c27bb8ef7c --- /dev/null +++ b/bob/bio/base/script/pipeline_train.py @@ -0,0 +1,220 @@ +#!/usr/bin/env python +# vim: set fileencoding=utf-8 : +# Tiago de Freitas Pereira <tiago.pereira@idiap.ch> + + +"""Executes only the train part of a biometric pipeline""" + +import logging + +import click + +from bob.extension.scripts.click_helper import ( + ConfigCommand, + ResourceOption, + verbosity_option, +) +from bob.pipelines.distributed import VALID_DASK_CLIENT_STRINGS + +logger = logging.getLogger(__name__) + + +EPILOG = """\b + +Command line examples\n +----------------------- + +$ bob bio pipeline train -vv DATABASE PIPELINE + +See the help of the CONFIG argument on top of this help message +for a list of available configurations. + +It is possible to provide database and pipeline through a configuration file. +Generate an example configuration file with: + +$ bob bio pipeline train --dump-config my_experiment.py + +and execute it with: + +$ bob bio pipeline train -vv my_experiment.py + +my_experiment.py must contain the following elements: + + >>> pipeline = ... # A scikit-learn pipeline wrapped with bob.pipelines' SampleWrapper\n + >>> database = .... # Biometric Database (class that implements the methods: `background_model_samples`, `references` and `probes`)" +\b""" + + +@click.command( + name="train", + entry_point_group="bob.bio.config", + cls=ConfigCommand, + epilog=EPILOG, +) +@click.option( + "--pipeline", + "-p", + required=True, + entry_point_group="bob.bio.pipeline", + help="A PipelineSimple or an sklearn.pipeline", + cls=ResourceOption, +) +@click.option( + "--database", + "-d", + entry_point_group="bob.bio.database", + required=True, + help="Biometric Database connector (class that implements the methods: `background_model_samples`, `references` and `probes`)", + cls=ResourceOption, +) +@click.option( + "--dask-client", + "-l", + entry_point_group="dask.client", + string_exceptions=VALID_DASK_CLIENT_STRINGS, + default="single-threaded", + help="Dask client for the execution of the pipeline.", + cls=ResourceOption, +) +@click.option( + "--output", + "-o", + show_default=True, + default="results", + help="Name of output directory where output files will be saved.", + cls=ResourceOption, +) +@click.option( + "--memory", + "-m", + is_flag=True, + help="If set, it will run the experiment keeping all objects on memory with nothing checkpointed. If not set, checkpoints will be saved in `--output`.", + cls=ResourceOption, +) +@click.option( + "--checkpoint-dir", + "-c", + show_default=True, + default=None, + help="Name of output directory where the checkpoints will be saved. In case --memory is not set, checkpoints will be saved in this directory.", + cls=ResourceOption, +) +@click.option( + "--dask-partition-size", + "-s", + help="If using Dask, this option defines the max size of each dask.bag.partition. " + "Use this option if the current heuristic that sets this value doesn't suit your experiment. 
" + "(https://docs.dask.org/en/latest/bag-api.html?highlight=partition_size#dask.bag.from_sequence).", + default=None, + type=click.INT, + cls=ResourceOption, +) +@click.option( + "--dask-n-partitions", + "-n", + help="If using Dask, this option defines a fixed number of dask.bag.partition for " + "each set of data. Use this option if the current heuristic that sets this value " + "doesn't suit your experiment." + "(https://docs.dask.org/en/latest/bag-api.html?highlight=partition_size#dask.bag.from_sequence).", + default=None, + type=click.INT, + cls=ResourceOption, +) +@click.option( + "--dask-n-workers", + "-w", + help="If using Dask, this option defines the number of workers to start your experiment. " + "Dask automatically scales up/down the number of workers due to the current load of tasks to be solved. " + "Use this option if the current amount of workers set to start an experiment doesn't suit you.", + default=None, + type=click.INT, + cls=ResourceOption, +) +@click.option( + "--force", + "-f", + is_flag=True, + help="If set, it will force generate all the checkpoints of an experiment. This option doesn't work if `--memory` is set", + cls=ResourceOption, +) +@click.option( + "--no-dask", + is_flag=True, + help="If set, it will not use Dask to run the experiment.", + cls=ResourceOption, +) +@click.option( + "--split-training", + is_flag=True, + help="Splits the training set in partitions and trains the pipeline in multiple steps.", + cls=ResourceOption, +) +@click.option( + "--n-splits", + default=3, + help="Number of partitions to split the training set in. " + "Each partition will be trained in a separate step.", + cls=ResourceOption, +) +@verbosity_option(cls=ResourceOption) +def pipeline_train( + pipeline, + database, + dask_client, + output, + memory, + checkpoint_dir, + dask_partition_size, + dask_n_workers, + dask_n_partitions, + force, + no_dask, + split_training, + n_splits, + **kwargs, +): + """Runs the training part of a biometrics pipeline. + + This pipeline consists only of one component, contrary to the ``simple`` pipeline. + This component is a scikit-learn ``Pipeline``, where a sequence of transformations + of the input data is defined. + + The pipeline is trained on the database and the resulting model is saved in the + output directory. + + It is possible to split the training data in multiple partitions that will be + used to train the pipeline in multiple steps, helping with big datasets that would + not fit in memory if trained all at once. Passing the ``--split-training`` option + will split the training data in ``--n-splits`` partitions and train the pipeline + sequentially with each partition. The pipeline must support "continuous learning", + (a call to ``fit`` on an already trained pipeline should continue the training). + """ + + from bob.bio.base.pipelines import execute_pipeline_train + + if no_dask: + dask_client = None + + checkpoint = not memory + + logger.debug("Executing pipeline training with:") + logger.debug(f"pipeline: {pipeline}") + logger.debug(f"database: {database}") + + execute_pipeline_train( + pipeline=pipeline, + database=database, + dask_client=dask_client, + output=output, + checkpoint=checkpoint, + dask_partition_size=dask_partition_size, + dask_n_partitions=dask_n_partitions, + dask_n_workers=dask_n_workers, + checkpoint_dir=checkpoint_dir, + force=force, + split_training=split_training, + n_splits=n_splits, + **kwargs, + ) + + logger.info(f"Experiment finished ! 
({output=})") diff --git a/bob/bio/base/test/test_pipeline_simple.py b/bob/bio/base/test/test_pipeline_simple.py index 7e5307b0b76abd0c362cdf3dc0a7200eb44b325b..0f2059fda6a886b6f4f0947d0cd93b4c5e702f35 100644 --- a/bob/bio/base/test/test_pipeline_simple.py +++ b/bob/bio/base/test/test_pipeline_simple.py @@ -27,7 +27,7 @@ from bob.bio.base.pipelines import ( from bob.bio.base.script.pipeline_simple import ( pipeline_simple as pipeline_simple_cli, ) -from bob.bio.base.test.test_transformers import FakeExtractor, FakePreprocesor +from bob.bio.base.test.test_transformers import FakeExtractor, FakePreprocessor from bob.bio.base.wrappers import wrap_bob_legacy from bob.extension.scripts.click_helper import assert_click_runner_result from bob.pipelines import DelayedSample, Sample, SampleSet @@ -191,7 +191,7 @@ class DistanceWithTags(Distance): def _make_transformer(dir_name): pipeline = make_pipeline( wrap_bob_legacy( - FakePreprocesor(), + FakePreprocessor(), dir_name, transform_extra_arguments=(("annotations", "annotations"),), ), diff --git a/bob/bio/base/test/test_pipeline_train.py b/bob/bio/base/test/test_pipeline_train.py new file mode 100644 index 0000000000000000000000000000000000000000..f54dded1c2c4da8caa7276f0de922579612e69eb --- /dev/null +++ b/bob/bio/base/test/test_pipeline_train.py @@ -0,0 +1,150 @@ +#!/usr/bin/env python +# @author: Yannick Dayer <yannick.dayer@idiap.ch> +# @date: Fri 19 Aug 2022 14:37:01 UTC+02 + +import glob +import os +import tempfile + +import pytest + +from click.testing import CliRunner +from sklearn.base import BaseEstimator +from sklearn.pipeline import Pipeline + +from bob.bio.base.pipelines.entry_points import execute_pipeline_train +from bob.bio.base.script.pipeline_train import ( + pipeline_train as pipeline_train_cli, +) +from bob.bio.base.test.test_pipeline_simple import DummyDatabase +from bob.bio.base.test.test_transformers import FakeExtractor, FakePreprocessor +from bob.bio.base.wrappers import wrap_bob_legacy +from bob.extension.scripts.click_helper import assert_click_runner_result +from bob.pipelines import wrap + + +class FittableTransformer(BaseEstimator): + def __init__(self): + super().__init__() + self.fitted_count = 0 + + def fit(self, X, y=None): + self.fitted_count += 1 + return self + + def transform(self, X): + return X + self.fitted_count + + def _more_tags(self): + return {"requires_fit": True} + + +def _make_transformer(dir_name): + pipeline = Pipeline( + [ + ( + "preprocessor", + wrap_bob_legacy( + FakePreprocessor(), + dir_name, + transform_extra_arguments=(("annotations", "annotations"),), + ), + ), + ( + "extractor", + wrap_bob_legacy( + FakeExtractor(), + dir_name, + ), + ), + ("fittable_transformer", wrap(["sample"], FittableTransformer())), + ] + ) + + return pipeline + + +def test_pipeline_train_function(): + with tempfile.TemporaryDirectory() as output: + pipeline = _make_transformer(output) + database = DummyDatabase() + execute_pipeline_train(pipeline, database, output=output) + assert os.path.isfile(os.path.join(output, "fittable_transformer.pkl")) + + +def _create_test_config_pipeline_simple(path): + with open(path, "w") as f: + f.write( + """ +from bob.bio.base.test.test_pipeline_train import DummyDatabase, _make_transformer +from bob.bio.base.pipelines import PipelineSimple +from bob.bio.base.algorithm import Distance + +database = DummyDatabase() + +transformer = _make_transformer(".") + +biometric_algorithm = Distance() + +pipeline = PipelineSimple( + transformer, + biometric_algorithm, + None, +) +""" + ) + + 
+def _create_test_config_pipeline_sklearn(path): + with open(path, "w") as f: + f.write( + """ +from bob.bio.base.test.test_pipeline_train import DummyDatabase, _make_transformer + +database = DummyDatabase() + +pipeline = _make_transformer(".") +""" + ) + + +@pytest.mark.parametrize( + "options,pipeline_simple", + [ + (["--no-dask", "--memory"], True), + (["--no-dask", "--memory"], False), + (["--no-dask"], True), + (["--no-dask"], False), + (["--memory"], True), + (["--memory"], False), + ([], True), + ([], False), + ], +) +def test_pipeline_click_cli( + options, + pipeline_simple, + expected_outputs=("results/fittable_transformer.pkl",), +): + runner = CliRunner() + with runner.isolated_filesystem(): + + if pipeline_simple: + _create_test_config_pipeline_simple("config.py") + else: + _create_test_config_pipeline_sklearn("config.py") + result = runner.invoke( + pipeline_train_cli, + [ + "-vv", + "config.py", + ] + + options, + ) + assert_click_runner_result(result) + # check for expected_output + output_files = glob.glob("results/**", recursive=True) + nl = "\n -" + err_msg = f"Found only:\n- {nl.join(output_files)}\nin output directory given the options: {options}, and with {'PipelineSimple' if pipeline_simple else 'sklearn pipeline'}" + for out in expected_outputs: + assert os.path.isfile(out), err_msg diff --git a/bob/bio/base/test/test_transformers.py b/bob/bio/base/test/test_transformers.py index 48f44da4f9accfb5e4bf814ea79b5c79a65c16b8..1c49dbe608882d1b21a94809728db20eb63254b4 100644 --- a/bob/bio/base/test/test_transformers.py +++ b/bob/bio/base/test/test_transformers.py @@ -17,7 +17,7 @@ from bob.bio.base.transformers import ( from bob.pipelines import CheckpointWrapper, Sample, SampleWrapper -class FakePreprocesor(Preprocessor): +class FakePreprocessor(Preprocessor): def __call__(self, data, annotations=None): return data + annotations @@ -82,7 +82,7 @@ def assert_checkpoints(transformed_sample, dir_name): def test_preprocessor(): - preprocessor = FakePreprocesor() + preprocessor = FakePreprocessor() preprocessor_transformer = PreprocessorTransformer(preprocessor) # Testing sample diff --git a/conda/meta.yaml b/conda/meta.yaml index ea3fd922fef42e9ab331f1d646142620440b0c84..f9c973058dcd688f8d104354a22d3a43d382b1e0 100644 --- a/conda/meta.yaml +++ b/conda/meta.yaml @@ -40,7 +40,7 @@ requirements: - {{ pin_compatible('click') }} - {{ pin_compatible('click-plugins') }} - {{ pin_compatible('dask') }} - - {{ pin_compatible('numpy') }} + - {{ pin_compatible('numpy', max_pin='x.x') }} - {{ pin_compatible('pandas') }} - {{ pin_compatible('scipy') }} - {{ pin_compatible('tabulate') }} @@ -64,7 +64,8 @@ test: - bob bio pipeline simple --help - bob bio pipeline score-norm --help - bob bio pipeline transform --help - - pytest --verbose --cov {{ name }} --cov-report term-missing --cov-report html:{{ project_dir }}/sphinx/coverage --cov-report xml:{{ project_dir }}/coverage.xml --pyargs {{ name }} + - bob bio pipeline train --help + - pytest --verbose --cov {{ name }} --cov-report term-missing --cov-report html:{{ project_dir }}/sphinx/coverage --cov-report xml:{{ project_dir }}/coverage.xml --junitxml={{ project_dir }}/test_results.xml --pyargs {{ name }} - sphinx-build -aEW {{ project_dir }}/doc {{ project_dir }}/sphinx - sphinx-build -aEb doctest {{ project_dir }}/doc sphinx - conda inspect linkages -p $PREFIX {{ name }} # [not win] diff --git a/setup.py b/setup.py index 431e003140df260a51289497e6fb5f888da5a833..091cff37297bde41dc80ecd3d53bb97ec667b6bc 100644 --- a/setup.py +++ b/setup.py 
@@ -107,6 +107,7 @@ setup( "simple = bob.bio.base.script.pipeline_simple:pipeline_simple", "score-norm = bob.bio.base.script.pipeline_score_norm:pipeline_score_norm", "transform = bob.bio.base.script.pipeline_transform:pipeline_transform", + "train = bob.bio.base.script.pipeline_train:pipeline_train", ], # Vulnerability analysis commands "bob.vuln.cli": [
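
For reference, a minimal sketch of driving the new ``execute_pipeline_train`` entry point (exported from ``bob.bio.base.pipelines`` by this change) from Python. It mirrors the new test configuration, reusing the dummy database and pipeline helpers from the added test module; for a real experiment, substitute a database exposing ``background_model_samples()`` and a sample-wrapped scikit-learn pipeline.

    from bob.bio.base.pipelines import execute_pipeline_train
    from bob.bio.base.test.test_pipeline_train import DummyDatabase, _make_transformer

    database = DummyDatabase()          # dummy database from the test helpers
    pipeline = _make_transformer(".")   # sample-wrapped sklearn Pipeline, as in the test configs

    # Train in three sequential passes over shuffled partitions of the
    # background-model samples; each intermediate state is pickled as
    # results/train_pipeline_step_<i>.pkl, and an interrupted run resumes
    # from the most recent step file found in the output directory.
    execute_pipeline_train(
        pipeline,
        database,
        output="./results",
        checkpoint=True,       # features are checkpointed under ``output``
        split_training=True,
        n_splits=3,
    )

    # Every pipeline step that requires fitting is saved as results/<step_name>.pkl,
    # e.g. results/fittable_transformer.pkl for the dummy pipeline above.

The equivalent command-line run is ``bob bio pipeline train -vv --split-training --n-splits 3 config.py``, where ``config.py`` defines ``database`` and ``pipeline`` as in the test configurations added by this patch.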