Commit fd04003d authored by Amir MOHAMMADI
Browse files

Fix dask_it, mix_me_up, and CheckpointMixin.load

parent 6736926a
Pipeline #38389 failed with stage
in 3 minutes and 29 seconds
......@@ -11,7 +11,6 @@ from sklearn.base import TransformerMixin
from sklearn.pipeline import Pipeline
from dask import delayed
import dask.bag
import os
def estimator_dask_it(
......@@ -88,25 +87,35 @@ def estimator_dask_it(
# Patching dask_resources
dasked = mix_me_up(
# Tagging each element in a pipeline
if isinstance(o, Pipeline):
if isinstance(o, Pipeline) and mix_for_each_step_in_pipelines:
# Tagging each element for fitting and transforming
if fit_tag is not None:
for t in fit_tag:
o.steps[t[0]][1].fit_tag = t[1]
for index, tag in fit_tag:
o.steps[index][1].fit_tag = tag
for estimator in o.steps:
estimator[1].fit_tag = fit_tag
if transform_tag is not None:
for t in transform_tag:
o.steps[t[0]][1].transform_tag = t[1]
for index, tag in transform_tag:
o.steps[index][1].transform_tag = tag
for estimator in o.steps:
estimator[1].transform_tag = transform_tag
for estimator in o.steps:
estimator.resource_tags = dict()
dasked.fit_tag = fit_tag
dasked.transform_tag = transform_tag
dasked.resource_tags = dict()
# Binding the method
dasked.dask_tags = types.MethodType(_fetch_resource_tape, dasked)
......@@ -149,7 +158,7 @@ def mix_me_up(bases, o, mix_for_each_step_in_pipelines=True):
def _mix(bases, o):
bases = bases if isinstance(bases, tuple) else tuple([bases])
bases = tuple(bases)
class_name = "".join([c.__name__ for c in bases])
if isinstance(o, types.ClassType):
# If it's a class, just merge them
......@@ -158,7 +167,8 @@ def mix_me_up(bases, o, mix_for_each_step_in_pipelines=True):
# If it's an object, create a new class and copy the state of the current object
class_name += o.__class__.__name__
new_type = type(class_name, bases + tuple([o.__class__]), o.__dict__)()
new_type = type(class_name, bases + tuple([o.__class__]), o.__dict__)
new_type = new_type.__new__(new_type)
# new_type.__dict__ is made in the descending order of the classes
# so the values of o.__dict__ are overwritten by the lower ones
# here we are copying them back
......@@ -280,7 +290,7 @@ class CheckpointMixin:
# save the new sample
new_sample = self.load(path)
new_sample = self.load(sample)
return new_sample
......@@ -317,11 +327,6 @@ class CheckpointMixin:
return os.path.join(self.features_dir, str(sample.key) + self.extension)
def recover_key_from_path(self, path):
key = path.replace(os.path.abspath(self.features_dir), "")
key = path[: -len(self.extension)]
return key
def save(self, sample):
if isinstance(sample, Sample):
path = self.make_path(sample)
......@@ -335,13 +340,13 @@ class CheckpointMixin:
raise ValueError("Type for sample not supported %s" % type(sample))
def load(self, path):
key = self.recover_key_from_path(path)
def load(self, sample):
path = self.make_path(sample)
# Because we are checkpointing, we return a DelayedSample
# instead of a normal (preloaded) sample. This allows the next
# phase to skip loading it if that is unnecessary (e.g. the next
# phase is already checkpointed).
return DelayedSample(functools.partial(self.load_func, path), key=key)
return DelayedSample(functools.partial(self.load_func, path), parent=sample)
def load_model(self):
if _is_estimator_stateless(self):
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment