Skip to content
Snippets Groups Projects

Fix dask_it, mix_me_up, and CheckpointMixin.load

Merged Amir MOHAMMADI requested to merge updates into master
1 file
+ 23
18
Compare changes
  • Side-by-side
  • Inline
+ 23
18
@@ -11,7 +11,6 @@ from sklearn.base import TransformerMixin
@@ -11,7 +11,6 @@ from sklearn.base import TransformerMixin
from sklearn.pipeline import Pipeline
from sklearn.pipeline import Pipeline
from dask import delayed
from dask import delayed
import dask.bag
import dask.bag
import os
def estimator_dask_it(
def estimator_dask_it(
@@ -88,25 +87,35 @@ def estimator_dask_it(
@@ -88,25 +87,35 @@ def estimator_dask_it(
# Patching dask_resources
# Patching dask_resources
dasked = mix_me_up(
dasked = mix_me_up(
DaskEstimatorMixin,
[DaskEstimatorMixin],
o,
o,
mix_for_each_step_in_pipelines=mix_for_each_step_in_pipelines,
mix_for_each_step_in_pipelines=mix_for_each_step_in_pipelines,
)
)
# Tagging each element in a pipeline
# Tagging each element in a pipeline
if isinstance(o, Pipeline):
if isinstance(o, Pipeline) and mix_for_each_step_in_pipelines:
# Tagging each element for fitting and transforming
# Tagging each element for fitting and transforming
if fit_tag is not None:
if fit_tag is not None:
for t in fit_tag:
for index, tag in fit_tag:
o.steps[t[0]][1].fit_tag = t[1]
o.steps[index][1].fit_tag = tag
 
else:
 
for estimator in o.steps:
 
estimator[1].fit_tag = fit_tag
if transform_tag is not None:
if transform_tag is not None:
for t in transform_tag:
for index, tag in transform_tag:
o.steps[t[0]][1].transform_tag = t[1]
o.steps[index][1].transform_tag = tag
 
else:
 
for estimator in o.steps:
 
estimator[1].transform_tag = transform_tag
 
 
for estimator in o.steps:
 
estimator.resource_tags = dict()
else:
else:
dasked.fit_tag = fit_tag
dasked.fit_tag = fit_tag
dasked.transform_tag = transform_tag
dasked.transform_tag = transform_tag
 
dasked.resource_tags = dict()
# Binding the method
# Binding the method
dasked.dask_tags = types.MethodType(_fetch_resource_tape, dasked)
dasked.dask_tags = types.MethodType(_fetch_resource_tape, dasked)
@@ -149,7 +158,7 @@ def mix_me_up(bases, o, mix_for_each_step_in_pipelines=True):
@@ -149,7 +158,7 @@ def mix_me_up(bases, o, mix_for_each_step_in_pipelines=True):
"""
"""
def _mix(bases, o):
def _mix(bases, o):
bases = bases if isinstance(bases, tuple) else tuple([bases])
bases = tuple(bases)
class_name = "".join([c.__name__ for c in bases])
class_name = "".join([c.__name__ for c in bases])
if isinstance(o, types.ClassType):
if isinstance(o, types.ClassType):
# If it's a class, just merge them
# If it's a class, just merge them
@@ -158,7 +167,8 @@ def mix_me_up(bases, o, mix_for_each_step_in_pipelines=True):
@@ -158,7 +167,8 @@ def mix_me_up(bases, o, mix_for_each_step_in_pipelines=True):
else:
else:
# If it's an object, create a new class and copy the state of the current object
# If it's an object, create a new class and copy the state of the current object
class_name += o.__class__.__name__
class_name += o.__class__.__name__
new_type = type(class_name, bases + tuple([o.__class__]), o.__dict__)()
new_type = type(class_name, bases + tuple([o.__class__]), o.__dict__)
 
new_type = new_type.__new__(new_type)
# new_type.__dict__ is made in the descending order of the classes
# new_type.__dict__ is made in the descending order of the classes
# so the values of o.__dict__ are overwritten by the lower ones
# so the values of o.__dict__ are overwritten by the lower ones
# here we are copying them back
# here we are copying them back
@@ -280,7 +290,7 @@ class CheckpointMixin:
@@ -280,7 +290,7 @@ class CheckpointMixin:
# save the new sample
# save the new sample
self.save(new_sample)
self.save(new_sample)
else:
else:
new_sample = self.load(path)
new_sample = self.load(sample)
return new_sample
return new_sample
@@ -317,11 +327,6 @@ class CheckpointMixin:
@@ -317,11 +327,6 @@ class CheckpointMixin:
return os.path.join(self.features_dir, str(sample.key) + self.extension)
return os.path.join(self.features_dir, str(sample.key) + self.extension)
def recover_key_from_path(self, path):
key = path.replace(os.path.abspath(self.features_dir), "")
key = path[: -len(self.extension)]
return key
def save(self, sample):
def save(self, sample):
if isinstance(sample, Sample):
if isinstance(sample, Sample):
path = self.make_path(sample)
path = self.make_path(sample)
@@ -335,13 +340,13 @@ class CheckpointMixin:
@@ -335,13 +340,13 @@ class CheckpointMixin:
else:
else:
raise ValueError("Type for sample not supported %s" % type(sample))
raise ValueError("Type for sample not supported %s" % type(sample))
def load(self, path):
def load(self, sample):
key = self.recover_key_from_path(path)
path = self.make_path(sample)
# because we are checkpointing, we return a DelayedSample
# because we are checkpointing, we return a DelayedSample
# instead of a normal (preloaded) sample. This allows the next
# instead of a normal (preloaded) sample. This allows the next
# phase to avoid loading it should that be unnecessary (e.g. next
# phase to avoid loading it should that be unnecessary (e.g. next
# phase is already check-pointed)
# phase is already check-pointed)
return DelayedSample(functools.partial(self.load_func, path), key=key)
return DelayedSample(functools.partial(self.load_func, path), parent=sample)
def load_model(self):
def load_model(self):
if _is_estimator_stateless(self):
if _is_estimator_stateless(self):
Loading