Skip to content
Snippets Groups Projects

Fix fit extra parameters

Merged Yannick DAYER requested to merge fix-fit-params into master
All threads resolved!
+ 39
20
@@ -123,6 +123,9 @@ def get_bob_tags(estimator=None, force_tags=None):
Indicates that the fit method of that estimator accepts dask arrays as input.
Default:
`{"bob_fit_supports_dask_array": False}`
bob_fit_expects_samplesets: bool
Indicates that the fit method of that estimator expects groups of samples
(not stacked data) as input.
Default:
`{"bob_fit_expects_samplesets": False}`
Parameters
----------
@@ -149,6 +152,7 @@ def get_bob_tags(estimator=None, force_tags=None):
"bob_features_save_fn": bob.io.base.save,
"bob_features_load_fn": bob.io.base.load,
"bob_fit_supports_dask_array": False,
"bob_fit_expects_samplesets": False,
}
estimator_tags = estimator._get_tags() if estimator is not None else {}
return {**default_tags, **estimator_tags, **force_tags}
@@ -616,10 +620,12 @@ def _len_samples(samples):
def _shape_samples(samples):
return [[s.shape for s in samples]]
return [[s.shape if hasattr(s, "shape") else (1,) for s in samples]]
def _array_from_sample_bags(X: dask.bag.Bag, attribute: str):
def _array_from_sample_bags(
X: dask.bag.Bag, attribute: str, stack_output: bool = True
):
# because samples could be delayed samples, we convert sample bags to
# sample.attribute bags first and then persist
@@ -655,7 +661,8 @@ def _array_from_sample_bags(X: dask.bag.Bag, attribute: str):
X.append(darray)
# stack data from all bags
X = da.vstack(X)
if stack_output:
X = da.vstack(X)
return X
@@ -665,17 +672,23 @@ class DaskWrapper(BaseWrapper, TransformerMixin):
Parameters
----------
fit_resource: str
Mark the delayed(self.fit) with this value. This can be used in
a future delayed(self.fit).compute(resources=resource_tape) so
dask scheduler can place this task in a particular resource
(e.g GPU)
transform_resource: str
Mark the delayed(self.transform) with this value. This can be used in
a future delayed(self.transform).compute(resources=resource_tape) so
dask scheduler can place this task in a particular resource
(e.g GPU)
fit_resource: str
Mark the delayed(self.fit) with this value. This can be used in
a future delayed(self.fit).compute(resources=resource_tape) so
the dask scheduler can place this task on a particular resource
(e.g. GPU).
transform_resource: str
Mark the delayed(self.transform) with this value. This can be used in
a future delayed(self.transform).compute(resources=resource_tape) so
the dask scheduler can place this task on a particular resource
(e.g. GPU).
fit_supports_dask_array: bool
Whether the fit method supports dask arrays as input.
fit_expects_samplesets: bool
Whether the fit method expects grouped data (not stacked) as input.
"""
def __init__(
@@ -684,6 +697,7 @@ class DaskWrapper(BaseWrapper, TransformerMixin):
fit_tag=None,
transform_tag=None,
fit_supports_dask_array=None,
fit_expects_samplesets=None,
**kwargs,
):
super().__init__(**kwargs)
@@ -692,9 +706,12 @@ class DaskWrapper(BaseWrapper, TransformerMixin):
self.resource_tags = dict()
self.fit_tag = fit_tag
self.transform_tag = transform_tag
bob_tags = get_bob_tags(self.estimator)
self.fit_supports_dask_array = (
fit_supports_dask_array
or get_bob_tags(self.estimator)["bob_fit_supports_dask_array"]
fit_supports_dask_array or bob_tags["bob_fit_supports_dask_array"]
)
self.fit_expects_samplesets = (
fit_expects_samplesets or bob_tags["bob_fit_expects_samplesets"]
)
def _make_dask_resource_tag(self, tag):
@@ -732,17 +749,17 @@ class DaskWrapper(BaseWrapper, TransformerMixin):
def score(self, samples):
return self._dask_transform(samples, "score")
def _get_fit_params_from_sample_bags(self, bags):
def _get_fit_params_from_sample_bags(self, bags, stack_output=True):
logger.debug("Converting dask bag to dask array")
input_attribute = getattr_nested(self, "input_attribute")
fit_extra_arguments = getattr_nested(self, "fit_extra_arguments")
# convert X which is a dask bag to a dask array
X = _array_from_sample_bags(bags, input_attribute)
X = _array_from_sample_bags(bags, input_attribute, stack_output)
kwargs = dict()
for arg, attr in fit_extra_arguments:
kwargs[arg] = _array_from_sample_bags(bags, attr)
kwargs[arg] = _array_from_sample_bags(bags, attr, stack_output)
return X, kwargs
def fit(self, X, y=None, **fit_params):
@@ -762,7 +779,9 @@ class DaskWrapper(BaseWrapper, TransformerMixin):
model_path = getattr_nested(self, "model_path")
model_path = model_path or ""
if not os.path.isfile(model_path):
X, fit_params = self._get_fit_params_from_sample_bags(X)
X, fit_params = self._get_fit_params_from_sample_bags(
X, not self.fit_expects_samplesets
)
# the estimators are supposed to be dask (self) | [checkpoint] | sample | estimator
estimator = self.estimator.estimator
if is_checkpointed(self):
Loading