Skip to content
Snippets Groups Projects
Commit 6fd5d9b3 authored by Christophe Ecabert's avatar Christophe Ecabert
Browse files

[CHG] SampleWrapper to chose the type of output

[CHG] DaskWrapper to avoid calling `fit` multiple times
parent be8b4c1b
No related branches found
No related tags found
1 merge request!97pipeline wrappers tweaks
Pipeline #62543 passed
......@@ -21,7 +21,7 @@ from sklearn.preprocessing import FunctionTransformer
import bob.io.base
from .sample import DelayedSample, SampleBatch, SampleSet
from .sample import DelayedSample, Sample, SampleBatch, SampleSet
logger = logging.getLogger(__name__)
......@@ -259,6 +259,9 @@ class SampleWrapper(BaseWrapper, TransformerMixin):
transform_extra_arguments : [tuple]
Similar to ``fit_extra_arguments`` but for the transform and other similar
methods.
delayed_output : bool
If ``True``, the output will be an instance of ``DelayedSample`` otherwise it
will be an instance of ``Sample``.
"""
def __init__(
......@@ -268,6 +271,7 @@ class SampleWrapper(BaseWrapper, TransformerMixin):
fit_extra_arguments=None,
output_attribute=None,
input_attribute=None,
delayed_output=True,
**kwargs,
):
super().__init__(**kwargs)
......@@ -282,6 +286,7 @@ class SampleWrapper(BaseWrapper, TransformerMixin):
fit_extra_arguments or bob_tags["bob_fit_extra_input"]
)
self.output_attribute = output_attribute or bob_tags["bob_output"]
self.delayed_output = delayed_output
def _samples_transform(self, samples, method_name):
# Transform either samples or samplesets
......@@ -312,10 +317,15 @@ class SampleWrapper(BaseWrapper, TransformerMixin):
setattr(s, self.output_attribute, delayed(i))
new_samples = samples
else:
new_samples = [
DelayedSample(partial(delayed, index=i), parent=s)
for i, s in enumerate(samples)
]
new_samples = []
for i, s in enumerate(samples):
if self.delayed_output:
sample = DelayedSample(
partial(delayed, index=i), parent=s
)
else:
sample = Sample(delayed(i), parent=s)
new_samples.append(sample)
return new_samples
def transform(self, samples):
......@@ -358,10 +368,17 @@ class SampleWrapper(BaseWrapper, TransformerMixin):
# if the estimator needs to be fitted.
logger.debug(f"{_frmt(self)}.fit")
kwargs = _make_kwargs_from_samples(samples, self.fit_extra_arguments)
# Samples is list of either Sample or DelayedSample created with
# DelayedSamplesCall function, therefore some element in the list can be
# None.
# Filter out invalid samples (i.e. samples[k] == None), otherwise
# SampleBatch will fail and throw exceptions
samples = [
s for s in samples if getattr(s, self.input_attribute) is not None
]
X = SampleBatch(samples, sample_attribute=self.input_attribute)
kwargs = _make_kwargs_from_samples(samples, self.fit_extra_arguments)
self.estimator = self.estimator.fit(X, **kwargs)
copy_learned_attributes(self.estimator, self)
return self
......@@ -903,7 +920,8 @@ class DaskWrapper(BaseWrapper, TransformerMixin):
# change the name to have a better name in dask graphs
_fit.__name__ = f"{_frmt(self)}.fit"
self._dask_state = delayed(_fit)(X, y)
_fit_call = delayed(_fit)(X, y, **fit_params)
self._dask_state = _fit_call.persist()
if self.fit_tag is not None:
# If you do `delayed(_fit)(X, y)`, two tasks are generated;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment