Commit 71fc43ac authored by Amir MOHAMMADI

Load checkpointed estimators inside the scheduler

parent 35ed3631
1 merge request: !89 Load checkpointed estimators inside the scheduler
Pipeline #61050 passed
@@ -2,6 +2,7 @@
 import logging
 import os
 import tempfile
+import time
 import traceback
 from functools import partial
@@ -500,7 +501,18 @@ class CheckpointWrapper(BaseWrapper, TransformerMixin):
                     and getattr(feat, self.sample_attribute) is not None
                 ):
                     self.save(feat)
-                    feat = self.load(s, p)
+                    # sometimes loading the file fails randomly
+                    for _ in range(self.attempts):
+                        try:
+                            feat = self.load(s, p)
+                            break
+                        except Exception:
+                            error = traceback.format_exc()
+                            time.sleep(0.1)
+                    else:
+                        raise RuntimeError(
+                            f"Could not load using: {self.load}({s}, {p}) with the following error: {error}"
+                        )
                 features.append(feat)
             else:
                 features.append(self.load(s, p))
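The retry loop above (and the matching one added to save in the next hunk) relies on Python's for/else idiom: the else clause runs only when the loop finishes without hitting break, i.e. when every attempt failed. A minimal, standalone sketch of the same pattern; call_with_retries, attempts, and delay are illustrative names, not bob.pipelines API:

import time
import traceback

def call_with_retries(func, *args, attempts=3, delay=0.1):
    # Retry func a few times; transient I/O failures often succeed on
    # a later attempt. Keep the last traceback for the error message.
    error = None
    for _ in range(attempts):
        try:
            return func(*args)
        except Exception:
            error = traceback.format_exc()
            time.sleep(delay)
    raise RuntimeError(
        f"Could not run {func} after {attempts} attempts:\n{error}"
    )

# usage, mirroring the diff: feat = call_with_retries(self.load, s, p)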
@@ -586,6 +598,7 @@ class CheckpointWrapper(BaseWrapper, TransformerMixin):
                 break
             except Exception:
                 error = traceback.format_exc()
+                time.sleep(0.1)
         else:
             raise RuntimeError(
                 f"Could not save {to_save} using {self.save_func} with the following error: {error}"
@@ -608,10 +621,9 @@ class CheckpointWrapper(BaseWrapper, TransformerMixin):
             return self
         with open(self.model_path, "rb") as f:
             loaded_estimator = cloudpickle.load(f)
-        estimator = self.estimator
         # We update self.estimator instead of replacing it because
         # self.estimator might be referenced elsewhere.
-        _update_estimator(estimator, loaded_estimator)
+        _update_estimator(self.estimator, loaded_estimator)
         return self
 
     def save_model(self):
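Note the in-place update: the wrapper mutates the existing self.estimator rather than rebinding the attribute, so any other object holding a reference to it sees the loaded state. A hypothetical sketch of what such a helper could look like; the real _update_estimator in the codebase may differ (e.g. by recursing into nested wrappers):

def _update_estimator_sketch(estimator, loaded_estimator):
    # Hypothetical: copy the loaded attributes onto the existing object
    # instead of replacing it, so external references stay valid.
    estimator.__dict__.update(loaded_estimator.__dict__)

class Model:
    pass

a = Model()
b = a  # a second reference to the same object
loaded = Model()
loaded.weights = [1, 2, 3]
_update_estimator_sketch(a, loaded)
assert b.weights == [1, 2, 3]  # the other reference sees the update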
@@ -766,7 +778,9 @@ class DaskWrapper(BaseWrapper, TransformerMixin):
         # change the name to have a better name in dask graphs
         _transf.__name__ = graph_name
-        map_partitions = X.map_partitions(_transf, self._dask_state)
+        # scatter the dask_state to all workers for efficiency
+        dask_state = dask.delayed(self._dask_state)
+        map_partitions = X.map_partitions(_transf, dask_state)
         if self.transform_tag:
             self.resource_tags[
                 tuple(map_partitions.dask.keys())
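Wrapping self._dask_state in dask.delayed turns it into a single node of the task graph, so the (potentially large) state is serialized once and shared by every partition task instead of being embedded in each one. A runnable sketch of the same idea, independent of this codebase:

import dask
import dask.bag as db

def _transf(partition, state):
    # every partition task receives the same shared state
    return [x + len(state) for x in partition]

bag = db.from_sequence(range(8), npartitions=4)
state = list(range(10_000))  # stand-in for a large state object

# without delayed, `state` would be copied into each of the four tasks;
# as a Delayed it becomes one graph node referenced by every task
shared = dask.delayed(state)
print(bag.map_partitions(_transf, shared).compute())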
@@ -865,11 +879,15 @@ class DaskWrapper(BaseWrapper, TransformerMixin):
             logger.info(
                 f"Checkpointed estimator detected at {model_path}. The estimator ({_frmt(self)}) will be loaded and training will not run."
             )
-        else:
-            if self.fit_supports_dask_array:
-                return self._fit_on_dask_array(X, y, **fit_params)
-            elif self.fit_supports_dask_bag:
-                return self._fit_on_dask_bag(X, y, **fit_params)
+            # we should load the estimator outside dask graph to make sure that
+            # the estimator loads in the scheduler
+            self.estimator.load_model()
+            return self
+
+        if self.fit_supports_dask_array:
+            return self._fit_on_dask_array(X, y, **fit_params)
+        elif self.fit_supports_dask_bag:
+            return self._fit_on_dask_bag(X, y, **fit_params)
 
         def _fit(X, y, **fit_params):
             try:
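The key change in this hunk is where load_model() runs: it is now called eagerly inside fit(), in the process that builds the dask graph (the scheduler side, per the commit title), rather than inside a delayed task that would execute on a worker. A hypothetical skeleton of that control flow, with toy classes standing in for the real wrappers:

import os

class ToyEstimator:
    # stand-in for a checkpointed estimator (illustrative only)
    def __init__(self, model_path):
        self.model_path = model_path

    def load_model(self):
        # the real code deserializes model_path with cloudpickle
        return self

class ToyDaskWrapper:
    def __init__(self, estimator):
        self.estimator = estimator

    def fit(self, X, y=None, **fit_params):
        model_path = self.estimator.model_path
        if model_path is not None and os.path.isfile(model_path):
            # load eagerly, before any dask task is submitted
            self.estimator.load_model()
            return self
        # ... otherwise build the dask graph and fit as usual
        return self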