
Fix dask_it, mix_me_up, and CheckpointMixin.load

Merged Amir MOHAMMADI requested to merge updates into master
1 file  +24  −21
@@ -11,7 +11,6 @@ from sklearn.base import TransformerMixin
 from sklearn.pipeline import Pipeline
 from dask import delayed
 import dask.bag
 import os


 def estimator_dask_it(
@@ -19,7 +18,6 @@ def estimator_dask_it(
     fit_tag=None,
     transform_tag=None,
     npartitions=None,
-    mix_for_each_step_in_pipelines=True,
 ):
     """
     Mix up any :py:class:`sklearn.pipeline.Pipeline` or :py:class:`sklearn.estimator.Base` with
@@ -88,9 +86,8 @@ def estimator_dask_it(
     # Patching dask_resources
     dasked = mix_me_up(
-        DaskEstimatorMixin,
+        [DaskEstimatorMixin],
         o,
-        mix_for_each_step_in_pipelines=mix_for_each_step_in_pipelines,
     )

     # Tagging each element in a pipeline
@@ -98,15 +95,25 @@ def estimator_dask_it(
         # Tagging each element for fitting and transforming
         if fit_tag is not None:
-            for t in fit_tag:
-                o.steps[t[0]][1].fit_tag = t[1]
+            for index, tag in fit_tag:
+                o.steps[index][1].fit_tag = tag
+        else:
+            for estimator in o.steps:
+                estimator[1].fit_tag = fit_tag

         if transform_tag is not None:
-            for t in transform_tag:
-                o.steps[t[0]][1].transform_tag = t[1]
+            for index, tag in transform_tag:
+                o.steps[index][1].transform_tag = tag
+        else:
+            for estimator in o.steps:
+                estimator[1].transform_tag = transform_tag
+
+        for estimator in o.steps:
+            estimator[1].resource_tags = dict()
     else:
         dasked.fit_tag = fit_tag
         dasked.transform_tag = transform_tag
+        dasked.resource_tags = dict()

     # Bounding the method
     dasked.dask_tags = types.MethodType(_fetch_resource_tape, dasked)
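For reference, the (index, tag) unpacking introduced above can be exercised on any scikit-learn pipeline, since the tags are plain attributes set on each step's estimator. The pipeline, step names and tag values below are made up for illustration; only the loop mirrors the patched code:

    from sklearn.decomposition import PCA
    from sklearn.pipeline import Pipeline
    from sklearn.preprocessing import StandardScaler

    pipeline = Pipeline([("scale", StandardScaler()), ("pca", PCA(n_components=2))])

    fit_tag = [(0, "cpu"), (1, "gpu")]  # hypothetical (step index, resource tag) pairs
    for index, tag in fit_tag:
        # steps[index] is a (name, estimator) tuple; the tag goes on the estimator
        pipeline.steps[index][1].fit_tag = tag

    for name, estimator in pipeline.steps:
        print(name, getattr(estimator, "fit_tag", None))  # scale cpu / pca gpu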
@@ -114,7 +121,7 @@ def estimator_dask_it(
     return dasked


-def mix_me_up(bases, o, mix_for_each_step_in_pipelines=True):
+def mix_me_up(bases, o):
     """
     Dynamically creates a new class from :any:`object` or :any:`class`.
     For instance, mix_me_up((A,B), class_c) is equal to `class ABC(A,B,C) pass:`
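The docstring's `mix_me_up((A, B), class_c)` example boils down to building a new class whose bases are `(A, B, C)`. A minimal, self-contained sketch of that equivalence with toy classes (not the library code):

    class A:
        def who(self):
            return "A+" + super().who()

    class B:
        def who(self):
            return "B+" + super().who()

    class C:
        def who(self):
            return "C"

    # Roughly what mix_me_up((A, B), C) produces for the class case:
    ABC = type("ABC", (A, B, C), {})
    print(ABC().who())  # -> "A+B+C", following the MRO A -> B -> C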
@@ -149,7 +156,7 @@ def mix_me_up(bases, o, mix_for_each_step_in_pipelines=True):
"""
def _mix(bases, o):
bases = bases if isinstance(bases, tuple) else tuple([bases])
bases = tuple(bases)
class_name = "".join([c.__name__ for c in bases])
if isinstance(o, types.ClassType):
# If it's a class, just merge them
@@ -158,7 +165,8 @@ def mix_me_up(bases, o, mix_for_each_step_in_pipelines=True):
         else:
             # If it's an object, creates a new class and copy the state of the current object
             class_name += o.__class__.__name__
-            new_type = type(class_name, bases + tuple([o.__class__]), o.__dict__)()
+            new_type = type(class_name, bases + tuple([o.__class__]), o.__dict__)
+            new_type = new_type.__new__(new_type)

             # new_type.__dict__ is made in the descending order of the classes
             # so the values of o.__dict__ are overwritten by the lower ones
             # here we are copying them back
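The two added lines matter because the old one-liner instantiated the freshly built type, which runs `__init__` with no arguments and breaks for classes whose `__init__` takes required parameters. `__new__` creates the instance without re-initialising it, and the state is copied back as the comments describe. A standalone sketch of that pattern, with made-up classes:

    class Mixin:
        def is_mixed(self):
            return True

    class Original:
        def __init__(self, value):  # required argument: calling the new type directly would fail
            self.value = value

    o = Original(42)
    new_type = type("MixinOriginal", (Mixin, Original), dict(o.__dict__))
    instance = new_type.__new__(new_type)  # no __init__ call here
    instance.__dict__.update(o.__dict__)   # copy the original object's state back
    print(instance.value, instance.is_mixed())  # -> 42 True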
@@ -168,7 +176,7 @@ def mix_me_up(bases, o, mix_for_each_step_in_pipelines=True):
     # If it is a scikit pipeline, mixIN everything inside of
     # Pipeline.steps
-    if isinstance(o, Pipeline) and mix_for_each_step_in_pipelines:
+    if isinstance(o, Pipeline):
         # mixing all pipelines
         for i in range(len(o.steps)):
             # checking if it's not the bag transformer
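In the pipeline branch, mixing happens per step: each `(name, estimator)` pair in `Pipeline.steps` is replaced by a mixed-in estimator. The sketch below uses a toy `TagMixin` and a simplified stand-in for the inner `_mix`; it is illustrative only, not the patched function:

    from sklearn.decomposition import PCA
    from sklearn.pipeline import Pipeline
    from sklearn.preprocessing import StandardScaler

    class TagMixin:
        fit_tag = None
        transform_tag = None

    def _mix_object(bases, obj):
        # Simplified version of the object branch shown above
        name = "".join(b.__name__ for b in bases) + obj.__class__.__name__
        new_type = type(name, tuple(bases) + (obj.__class__,), dict(obj.__dict__))
        mixed = new_type.__new__(new_type)
        mixed.__dict__.update(obj.__dict__)
        return mixed

    pipeline = Pipeline([("scale", StandardScaler()), ("pca", PCA(n_components=2))])
    for i in range(len(pipeline.steps)):
        name, estimator = pipeline.steps[i]
        pipeline.steps[i] = (name, _mix_object([TagMixin], estimator))

    print([type(est).__name__ for _, est in pipeline.steps])
    # -> ['TagMixinStandardScaler', 'TagMixinPCA']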
@@ -280,7 +288,7 @@ class CheckpointMixin:
             # save the new sample
             self.save(new_sample)
         else:
-            new_sample = self.load(path)
+            new_sample = self.load(sample)

         return new_sample
@@ -317,11 +325,6 @@ class CheckpointMixin:
         return os.path.join(self.features_dir, str(sample.key) + self.extension)

-    def recover_key_from_path(self, path):
-        key = path.replace(os.path.abspath(self.features_dir), "")
-        key = path[: -len(self.extension)]
-        return key
-
     def save(self, sample):
         if isinstance(sample, Sample):
             path = self.make_path(sample)
@@ -335,13 +338,13 @@ class CheckpointMixin:
         else:
             raise ValueError("Type for sample not supported %s" % type(sample))

-    def load(self, path):
-        key = self.recover_key_from_path(path)
+    def load(self, sample):
+        path = self.make_path(sample)
         # because we are checkpointing, we return a DelayedSample
         # instead of a normal (preloaded) sample. This allows the next
         # phase to avoid loading it would it be unnecessary (e.g. next
         # phase is already check-pointed)
-        return DelayedSample(functools.partial(self.load_func, path), key=key)
+        return DelayedSample(functools.partial(self.load_func, path), parent=sample)

     def load_model(self):
         if _is_estimator_stateless(self):
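The comment in `load` explains the intent: rather than reading features eagerly, the checkpointed result is wrapped so the data is only loaded when a later stage actually asks for it. The snippet below illustrates that `functools.partial` pattern with a minimal stand-in class; `DelayedSample` itself lives in the library and its full interface is not shown here:

    import functools

    class LazySample:
        """Toy stand-in: data is only read when .data is accessed."""

        def __init__(self, load, parent=None):
            self._load = load
            self.parent = parent

        @property
        def data(self):
            return self._load()

    def load_func(path):
        print("reading", path)
        return "features stored at " + path

    path = "/tmp/features/sample-1.h5"  # hypothetical checkpoint path
    sample = LazySample(functools.partial(load_func, path))

    # Nothing touches the disk until .data is used:
    print(sample.data)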