Skip to content
Snippets Groups Projects
Commit 3a991006 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Replaced callable by instance in the transformers

parent ce29f3bb
No related branches found
No related tags found
2 merge requests!185Wrappers and aggregators,!180[dask] Preparing bob.bio.base for dask pipelines
Pipeline #39671 passed
......@@ -26,56 +26,56 @@ class AlgorithmTransformer(TransformerMixin, BaseEstimator):
Parameters
----------
callable: ``collections.callable``
instance: ``bob.bio.base.algorithm.Algorithm``
Callable function that instantiates the bob.bio.base.algorithm.Algorithm
"""
def __init__(
    self, instance, projector_file=None, **kwargs,
):
    """Wrap a legacy ``bob.bio.base.algorithm.Algorithm`` instance.

    Parameters
    ----------
    instance : ``bob.bio.base.algorithm.Algorithm``
        The algorithm whose ``project``/``train_projector`` methods this
        transformer delegates to.
    projector_file : str, optional
        Path where the trained projector is written by ``fit``.  Required
        whenever ``instance.requires_projector_training`` is True.

    Raises
    ------
    ValueError
        If ``instance`` is not an ``Algorithm``, if it requires training but
        no ``projector_file`` was given, or if it is not picklable.
    """
    if not isinstance(instance, Algorithm):
        raise ValueError(
            "`instance` should be an instance of `bob.bio.base.algorithm.Algorithm`"
        )
    # A projector file is mandatory when training is required, because ``fit``
    # persists the trained projector there.
    if instance.requires_projector_training and (
        projector_file is None or projector_file == ""
    ):
        raise ValueError(
            f"`projector_file` needs to be set if extractor {instance} requires training"
        )
    # Pipelines (e.g. dask) serialize this transformer, so the wrapped
    # algorithm must be picklable.
    if not is_picklable(instance):
        raise ValueError(f"{instance} needs to be picklable")
    self.instance = instance
    self.projector_file = projector_file
    super().__init__(**kwargs)
def fit(self, X, y=None):
    """Train the wrapped algorithm's projector, if it requires training.

    Returns ``self`` unchanged (no-op) when the wrapped algorithm does not
    require projector training.  Otherwise the trained projector is written
    to ``self.projector_file``.
    """
    if not self.instance.requires_projector_training:
        return self
    training_data = X
    # Some algorithms expect the training features grouped per client
    # (identity), keyed by the labels ``y``.
    if self.instance.split_training_features_by_client:
        training_data = split_X_by_y(X, y)
    # Ensure the destination directory exists before the algorithm writes.
    os.makedirs(os.path.dirname(self.projector_file), exist_ok=True)
    self.instance.train_projector(training_data, self.projector_file)
    return self
def transform(self, X, metadata=None):
    """Project every sample in ``X`` with the wrapped algorithm.

    When ``metadata`` is given it must parallel ``X``; each sample is then
    projected together with its matching metadata entry.
    """
    if metadata is None:
        return [self.instance.project(data) for data in X]
    # NOTE: use a distinct loop variable so the ``metadata`` parameter is not
    # shadowed inside the comprehension.
    return [
        self.instance.project(data, md) for data, md in zip(X, metadata)
    ]
def _more_tags(self):
return {
"stateless": not self.callable.requires_projector_training,
"requires_fit": self.callable.requires_projector_training,
"stateless": not self.instance.requires_projector_training,
"requires_fit": self.instance.requires_projector_training,
}
......@@ -13,7 +13,7 @@ class ExtractorTransformer(TransformerMixin, BaseEstimator):
Parameters
----------
callable: ``collections.Callable``
instance: ``bob.bio.base.extractor.Extractor``
Instance of `bob.bio.base.extractor.Extractor`
model_path: ``str``
......@@ -22,44 +22,44 @@ class ExtractorTransformer(TransformerMixin, BaseEstimator):
"""
def __init__(
    self, instance, model_path=None, **kwargs,
):
    """Wrap a legacy ``bob.bio.base.extractor.Extractor`` instance.

    Parameters
    ----------
    instance : ``bob.bio.base.extractor.Extractor``
        The extractor this transformer delegates to.
    model_path : str, optional
        Path where the trained extractor model is written by ``fit``.
        Required whenever ``instance.requires_training`` is True.

    Raises
    ------
    ValueError
        If ``instance`` is not an ``Extractor``, or if it requires training
        but no ``model_path`` was given.
    """
    if not isinstance(instance, Extractor):
        raise ValueError(
            "`instance` should be an instance of `bob.bio.base.extractor.Extractor`"
        )
    # A model path is mandatory when training is required, because ``fit``
    # persists the trained model there.
    if instance.requires_training and (model_path is None or model_path == ""):
        raise ValueError(
            f"`model_path` needs to be set if extractor {instance} requires training"
        )
    self.instance = instance
    self.model_path = model_path
    super().__init__(**kwargs)
def fit(self, X, y=None):
    """Train the wrapped extractor, if it requires training.

    Returns ``self`` unchanged (no-op) when the wrapped extractor does not
    require training.  Otherwise the trained model is written to
    ``self.model_path``.
    """
    if not self.instance.requires_training:
        return self
    training_data = X
    # Some extractors expect the training data grouped per client
    # (identity), keyed by the labels ``y``.
    if self.instance.split_training_data_by_client:
        training_data = split_X_by_y(X, y)
    self.instance.train(training_data, self.model_path)
    return self
def transform(self, X, metadata=None):
    """Extract features from every sample in ``X``.

    When ``metadata`` is given it must parallel ``X``; each sample is then
    extracted together with its matching metadata entry.
    """
    if metadata is None:
        return [self.instance(data) for data in X]
    # NOTE: use a distinct loop variable so the ``metadata`` parameter is not
    # shadowed inside the comprehension.
    return [self.instance(data, md) for data, md in zip(X, metadata)]
def _more_tags(self):
return {
"stateless": not self.callable.requires_training,
"requires_fit": self.callable.requires_training,
"stateless": not self.instance.requires_training,
"requires_fit": self.instance.requires_training,
}
......@@ -11,7 +11,7 @@ class PreprocessorTransformer(TransformerMixin, BaseEstimator):
Parameters
----------
callable: ``collections.Callable``
instance: ``bob.bio.base.preprocessor.Preprocessor``
Instance of `bob.bio.base.preprocessor.Preprocessor`
......@@ -19,21 +19,21 @@ class PreprocessorTransformer(TransformerMixin, BaseEstimator):
def __init__(
    self,
    instance,
    **kwargs,
):
    """Wrap a legacy ``bob.bio.base.preprocessor.Preprocessor`` instance.

    Parameters
    ----------
    instance : ``bob.bio.base.preprocessor.Preprocessor``
        The preprocessor this transformer delegates to.

    Raises
    ------
    ValueError
        If ``instance`` is not a ``Preprocessor``.
    """
    if not isinstance(instance, Preprocessor):
        raise ValueError("`instance` should be an instance of `bob.bio.base.preprocessor.Preprocessor`")
    self.instance = instance
    super().__init__(**kwargs)
def transform(self, X, annotations=None):
    """Preprocess every sample in ``X``.

    When ``annotations`` is given it must parallel ``X``; each sample is
    then preprocessed together with its matching annotation.
    """
    if annotations is None:
        return [self.instance(data) for data in X]
    return [self.instance(data, annot) for data, annot in zip(X, annotations)]
def _more_tags(self):
return {"stateless": True, "requires_fit": False}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment