Commit 639cb0c6 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI

Add support for fitting estimators on dask bags

Estimators that can handle dask bags should set the
`bob_fit_supports_dask_bag` tag to True.
This commit also:
* Adds a new tag: `bob_fit_supports_dask_bag`
* Adds a new tag: `bob_checkpoint_features`, for when you want to always
  avoid checkpointing the features of a specific estimator.
* Exposes dask_tags and get_bob_tags in the main API
* Modifies SampleWrapper to support `bob_fit_supports_dask_bag`
* Fixes CheckpointWrapper to load estimators without losing references.
parent fc80d8d5
@@ -10,15 +10,16 @@ from .sample import (
     SampleBatch,
     SampleSet,
 )
-from .wrappers import dask_tags # noqa: F401
 from .wrappers import wrap # noqa: F401
-from .wrappers import (
+from .wrappers import ( # noqa: F401
     BaseWrapper,
     CheckpointWrapper,
     DaskWrapper,
     DelayedSamplesCall,
     SampleWrapper,
     ToDaskBag,
+    dask_tags,
+    get_bob_tags,
 )
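With this change, the tag helpers are importable from the package top level. A quick sketch (assuming the package's import name is `bob.pipelines` and a scikit-learn version that still provides `_get_tags`):

    from bob.pipelines import get_bob_tags, wrap

    # get_bob_tags merges the defaults with whatever the estimator declares,
    # so a plain scikit-learn estimator reports the default values:
    from sklearn.preprocessing import StandardScaler

    tags = get_bob_tags(StandardScaler())
    assert tags["bob_fit_supports_dask_bag"] is False
    assert tags["bob_checkpoint_features"] is True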
@@ -121,11 +121,22 @@ def get_bob_tags(estimator=None, force_tags=None):
         Default:
         `{"bob_features_load_fn": bob.io.base.load}`
     bob_fit_supports_dask_array: bool
-        Indicates that the fit method of that estimator accepts dask arrays as input.
-        You may only use this tag if you accept X (N, M) and optionally y (N) as input.
-        The fit function may not accept any other input.
+        Indicates that the fit method of that estimator accepts dask arrays as
+        input. You may only use this tag if you accept X (N, M) and optionally y
+        (N) as input. The fit function may not accept any other input.
         Default:
         `{"bob_fit_supports_dask_array": False}`
+    bob_fit_supports_dask_bag: bool
+        Indicates that the fit method of that estimator accepts dask bags as
+        input. If true, each input parameter of fit will be a dask bag. You
+        can (and normally should) still wrap your estimator with the
+        SampleWrapper so the same code runs with and without dask.
+        Default:
+        `{"bob_fit_supports_dask_bag": False}`
+    bob_checkpoint_features: bool
+        If False, the features of the estimator will never be saved.
+        Default:
+        `{"bob_checkpoint_features": True}`

     Parameters
     ----------
@@ -152,6 +163,8 @@ def get_bob_tags(estimator=None, force_tags=None):
         "bob_features_save_fn": bob.io.base.save,
         "bob_features_load_fn": bob.io.base.load,
         "bob_fit_supports_dask_array": False,
+        "bob_fit_supports_dask_bag": False,
+        "bob_checkpoint_features": True,
     }
     estimator_tags = estimator._get_tags() if estimator is not None else {}
     return {**default_tags, **estimator_tags, **force_tags}
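An estimator opts into bag-based fitting through scikit-learn's tag mechanism, since `get_bob_tags` reads `estimator._get_tags()`. A minimal sketch (the class name and its fit body are illustrative, not part of this commit):

    import dask.bag
    from sklearn.base import BaseEstimator

    class BagFriendlyEstimator(BaseEstimator):
        def _more_tags(self):
            # picked up by get_bob_tags via estimator._get_tags()
            return {"bob_fit_supports_dask_bag": True}

        def fit(self, X, y=None):
            # X arrives as a dask bag when fitting through the DaskWrapper
            self.n_samples_ = X.count().compute()
            return self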
@@ -317,7 +330,14 @@ class SampleWrapper(BaseWrapper, TransformerMixin):
     def score(self, samples):
         return self._samples_transform(samples, "score")

-    def fit(self, samples, y=None):
+    def fit(self, samples, y=None, **kwargs):
+        # If samples is a dask bag, pass the arguments through unmodified;
+        # the data was already prepared in the DaskWrapper.
+        if isinstance(samples, dask.bag.core.Bag):
+            logger.debug(f"{_frmt(self)}.fit")
+            self.estimator.fit(samples, y, **kwargs)
+            return self
+
         if y is not None:
             raise TypeError(
                 "We don't accept `y` in fit arguments because `y` should be part of "
@@ -407,6 +427,19 @@ class CheckpointWrapper(BaseWrapper, TransformerMixin):
         **kwargs,
     ):
         super().__init__(**kwargs)
+        bob_tags = get_bob_tags(estimator)
+        self.extension = extension or bob_tags["bob_checkpoint_extension"]
+        self.save_func = save_func or bob_tags["bob_features_save_fn"]
+        self.load_func = load_func or bob_tags["bob_features_load_fn"]
+        self.sample_attribute = sample_attribute or bob_tags["bob_output"]
+
+        if not bob_tags["bob_checkpoint_features"]:
+            logger.info(
+                "Checkpointing is disabled for %s because the bob_checkpoint_features tag is False.",
+                estimator,
+            )
+            features_dir = None
+
         self.force = force
         self.estimator = estimator
         self.model_path = model_path
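For example, to guarantee that an estimator's features are never written to disk, one could declare the tag like this (a sketch; `NoFeatureCache` is a hypothetical name):

    from sklearn.base import BaseEstimator, TransformerMixin
    from bob.pipelines import CheckpointWrapper

    class NoFeatureCache(TransformerMixin, BaseEstimator):
        def _more_tags(self):
            return {"bob_checkpoint_features": False}

        def transform(self, X):
            return X

    # features_dir is discarded in __init__, so transform outputs are
    # recomputed instead of being checkpointed:
    ckpt = CheckpointWrapper(NoFeatureCache(), features_dir="/tmp/features")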
@@ -414,12 +447,6 @@ class CheckpointWrapper(BaseWrapper, TransformerMixin):
         self.hash_fn = hash_fn
         self.attempts = attempts
-        bob_tags = get_bob_tags(self.estimator)
-        self.extension = extension or bob_tags["bob_checkpoint_extension"]
-        self.save_func = save_func or bob_tags["bob_features_save_fn"]
-        self.load_func = load_func or bob_tags["bob_features_load_fn"]
-        self.sample_attribute = sample_attribute or bob_tags["bob_output"]

         # Paths check
         if model_path is None and features_dir is None:
             logger.warning(
@@ -578,17 +605,9 @@ class CheckpointWrapper(BaseWrapper, TransformerMixin):
         with open(self.model_path, "rb") as f:
             loaded_estimator = cloudpickle.load(f)
         estimator = self.estimator
-        # For the update, ensure that we have the estimator, not a wrapper
-        while hasattr(estimator, "estimator"):
-            # Update this estimator's __dict__, except for the "estimator" attribute
-            for k, v in loaded_estimator.__dict__.items():
-                if k != "estimator":
-                    estimator.__dict__[k] = v
-            estimator = estimator.estimator
-            loaded_estimator = loaded_estimator.estimator
-        # we don't do self.estimator = loaded_estimator, because self.estimator
-        # might be used elsewhere
-        estimator.__dict__.update(loaded_estimator.__dict__)
+        # We update self.estimator instead of replacing it because
+        # self.estimator might be referenced elsewhere.
+        _update_estimator(estimator, loaded_estimator)
         return self

     def save_model(self):
@@ -600,6 +619,16 @@ class CheckpointWrapper(BaseWrapper, TransformerMixin):
         return self


+def _update_estimator(estimator, loaded_estimator):
+    # Recursively update estimator with loaded_estimator, without replacing
+    # estimator.estimator
+    if hasattr(estimator, "estimator"):
+        _update_estimator(estimator.estimator, loaded_estimator.estimator)
+    for k, v in loaded_estimator.__dict__.items():
+        if k != "estimator":
+            estimator.__dict__[k] = v
+
+
 def is_checkpointed(estimator):
     return isinstance_nested(estimator, "estimator", CheckpointWrapper)
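The effect of `_update_estimator` can be demonstrated with two dummy nested objects (illustrative only, not from the commit): fitted attributes are copied level by level while every pre-existing `estimator` reference stays valid.

    class _Obj:
        pass

    # live chain: outer -> inner (other code may hold references to either)
    inner, outer = _Obj(), _Obj()
    outer.estimator = inner

    # deserialized chain carrying the fitted state
    loaded_inner, loaded_outer = _Obj(), _Obj()
    loaded_outer.estimator = loaded_inner
    loaded_outer.fitted_ = True
    loaded_inner.coef_ = [1, 2, 3]

    _update_estimator(outer, loaded_outer)
    assert outer.estimator is inner   # the reference is preserved
    assert outer.fitted_ and inner.coef_ == [1, 2, 3]  # state was copied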
@@ -703,6 +732,7 @@ class DaskWrapper(BaseWrapper, TransformerMixin):
         fit_tag=None,
         transform_tag=None,
         fit_supports_dask_array=None,
+        fit_supports_dask_bag=None,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -715,6 +745,10 @@
             fit_supports_dask_array
             or get_bob_tags(self.estimator)["bob_fit_supports_dask_array"]
         )
+        self.fit_supports_dask_bag = (
+            fit_supports_dask_bag
+            or get_bob_tags(self.estimator)["bob_fit_supports_dask_bag"]
+        )

     def _make_dask_resource_tag(self, tag):
         return {tag: 1}
@@ -752,7 +786,7 @@
         return self._dask_transform(samples, "score")

     def _get_fit_params_from_sample_bags(self, bags):
-        logger.debug("Converting dask bag to dask array")
+        logger.debug("Preparing data as dask arrays for fit")
         input_attribute = getattr_nested(self, "input_attribute")
         fit_extra_arguments = getattr_nested(self, "fit_extra_arguments")
@@ -771,37 +805,67 @@ class DaskWrapper(BaseWrapper, TransformerMixin):
         return X, kwargs

+    def _fit_on_dask_array(self, bags, y=None, **fit_params):
+        if y is not None or fit_params:
+            raise ValueError(
+                "y or fit_params should be passed through fit_extra_arguments of the SampleWrapper"
+            )
+        X, fit_params = self._get_fit_params_from_sample_bags(bags)
+
+        # The estimators are supposed to be dask (self) | [checkpoint] | sample | estimator
+        estimator = self.estimator.estimator
+        if is_checkpointed(self):
+            estimator = estimator.estimator
+        estimator.fit(X, **fit_params)
+
+        # If the estimator is checkpointed, we need to save the model
+        if is_checkpointed(self):
+            self.estimator.save_model()
+        return self
+
+    def _fit_on_dask_bag(self, bags, y=None, **fit_params):
+        # X is a dask bag of Samples; convert it to the required fit parameters
+        logger.debug("Converting a dask bag of samples to bags of fit parameters")
+
+        def getattr_list(samples, attribute):
+            return SampleBatch(samples, sample_attribute=attribute)
+
+        # We prepare the input parameters here instead of doing it in the
+        # SampleWrapper. The SampleWrapper will then pass these dask bags
+        # directly to the underlying estimator.
+        bob_tags = get_bob_tags(self.estimator)
+        input_attribute = bob_tags["bob_input"]
+        fit_extra_arguments = bob_tags["bob_fit_extra_input"]
+        X = bags.map_partitions(getattr_list, input_attribute)
+        kwargs = {
+            arg: bags.map_partitions(getattr_list, attr)
+            for arg, attr in fit_extra_arguments
+        }
+
+        self.estimator.fit(X, **kwargs)
+
+        # If the estimator is checkpointed, we need to save the model
+        if is_checkpointed(self):
+            self.estimator.save_model()
+        return self
+
     def fit(self, X, y=None, **fit_params):
         if is_estimator_stateless(self.estimator):
             return self
         logger.debug(f"{_frmt(self)}.fit")
-        if self.fit_supports_dask_array:
-            if y is not None or fit_params:
-                raise ValueError(
-                    "y or fit_params should be passed through fit_extra_arguments of the SampleWrapper"
-                )
-            model_path = None
-            if is_checkpointed(self):
-                model_path = getattr_nested(self, "model_path")
-            model_path = model_path or ""
-            if not os.path.isfile(model_path):
-                X, fit_params = self._get_fit_params_from_sample_bags(X)
-                # the estimators are supposed to be dask (self) | [checkpoint] | sample | estimator
-                estimator = self.estimator.estimator
-                if is_checkpointed(self):
-                    estimator = estimator.estimator
-                estimator.fit(X, **fit_params)
-                # if the estimator is checkpointed, we need to save the model
-                if is_checkpointed(self):
-                    self.estimator.save_model()
-                return self
-            else:
-                logger.info(
-                    f"Ignoring conversion to dask array (checkpoint detected at {model_path})"
-                )
+
+        model_path = None
+        if is_checkpointed(self):
+            model_path = getattr_nested(self, "model_path")
+        model_path = model_path or ""
+
+        if os.path.isfile(model_path):
+            logger.info(
+                f"Checkpointed estimator detected at {model_path}. The estimator "
+                "will be loaded and training will not run."
+            )
+        else:
+            if self.fit_supports_dask_array:
+                return self._fit_on_dask_array(X, y, **fit_params)
+            elif self.fit_supports_dask_bag:
+                return self._fit_on_dask_bag(X, y, **fit_params)

         def _fit(X, y, **fit_params):
             try:
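To see what `_fit_on_dask_bag` does with `map_partitions`, here is a standalone sketch where a plain list stands in for SampleBatch: each partition of samples is turned into one batch per fit argument, lazily, on the workers.

    import dask.bag
    from types import SimpleNamespace

    samples = [SimpleNamespace(data=i, label=i % 2) for i in range(6)]
    bag = dask.bag.from_sequence(samples, npartitions=2)

    def getattr_list(samples, attribute):
        # stand-in for SampleBatch(samples, sample_attribute=attribute)
        return [getattr(s, attribute) for s in samples]

    X = bag.map_partitions(getattr_list, "data")
    y = bag.map_partitions(getattr_list, "label")
    print(X.compute())  # [0, 1, 2, 3, 4, 5]
    print(y.compute())  # [0, 1, 0, 1, 0, 1]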
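Putting the pieces together, fitting a bag-aware estimator end to end could look like the following (a hedged sketch: it reuses the illustrative `BagFriendlyEstimator` and assumes `wrap` accepts the "sample" and "dask" keys as in the library's existing wrapping convention):

    import dask.bag
    from bob.pipelines import Sample, wrap

    samples = [Sample(data=i) for i in range(10)]
    bag = dask.bag.from_sequence(samples, npartitions=2)

    # SampleWrapper + DaskWrapper; because bob_fit_supports_dask_bag is True,
    # DaskWrapper.fit routes through _fit_on_dask_bag instead of converting
    # the bag to a dask array.
    pipeline = wrap(["sample", "dask"], BagFriendlyEstimator())
    pipeline = pipeline.fit(bag)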