Commit 9a3d419a authored by Tiago de Freitas Pereira

Improvements on CheckpointWrapper

parent 6ea25543
Pipeline #45905 passed in 8 minutes and 12 seconds
@@ -11,7 +11,7 @@ from sklearn.preprocessing import FunctionTransformer
from sklearn.utils.estimator_checks import check_estimator
from sklearn.utils.validation import check_array
from sklearn.utils.validation import check_is_fitted
from bob.pipelines.utils import hash_string
import bob.pipelines as mario
@@ -175,6 +175,17 @@ def test_checkpoint_function_sample_transfomer():
features = transformer.transform(samples)
_assert_all_close_numpy_array(oracle, [s.data for s in features])
# test when both model_path and features_dir are None
transformer = mario.wrap(
[FunctionTransformer, "sample", "checkpoint"],
func=_offset_add_func,
kw_args=dict(offset=offset),
validate=True,
hash_fn=hash_string,
)
features = transformer.transform(samples)
_assert_all_close_numpy_array(oracle, [s.data for s in features])
def test_checkpoint_fittable_sample_transformer():
X = np.ones(shape=(10, 2), dtype=int)
@@ -2,6 +2,8 @@ import pickle
import nose
import numpy as np
import random
import string
def is_picklable(obj):
@@ -76,3 +78,33 @@ def isinstance_nested(instance, attribute, isinstance_of):
else:
# Recursive search
return isinstance_nested(getattr(instance, attribute), attribute, isinstance_of)
def hash_string(key, bucket_size=1000, word_length=8):
"""
Generates a hash code given a string.
Parameters
----------
key: str
Input string to be hashed
bucket_size: int
Size of the hash table.
word_lenth: str
Size of the output string
"""
letters = string.ascii_lowercase
# Getting an integer value from the key
# and mod `bucket_size` to have values between 0 and bucket_size - 1
string_seed = sum([ord(i) for i in (key)]) % bucket_size
# Defining the seed so we have predictable values
random.seed(string_seed)
return "".join(random.choice(letters) for i in range(word_length))
@@ -137,8 +137,7 @@ class SampleWrapper(BaseWrapper, TransformerMixin):
if isinstance(samples[0], SampleSet):
return [
SampleSet(
self._samples_transform(sset.samples, method_name),
parent=sset,
self._samples_transform(sset.samples, method_name), parent=sset,
)
for sset in samples
]
@@ -201,9 +200,34 @@ class CheckpointWrapper(BaseWrapper, TransformerMixin):
Parameters
----------
estimator
The scikit-learn estimator to be wrapped.
model_path: str
Saves the estimator state in this directory if the `estimator` is stateful
features_dir: str
Saves the transformed data in this directory
extension: str
Default extension of the transformed features
save_func
Pointer to a customized function that saves transformed features to disk
load_func
Pointer to a customized function that loads transformed features from disk
sample_attribute: str
Defines the payload attribute of the sample that is saved to disk [default: ``data``].
hash_fn
Pointer to a hash function. This hash function maps
`sample.key` to a hash code and this hash code corresponds to a relative directory
where a single `sample` will be checkpointed.
This is useful when it is desirable to keep directories with fewer than
a certain number of files.
"""
def __init__(
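
A minimal sketch of how the new `hash_fn` option is meant to be used, modeled on the test added in this commit (the offset function, sample keys and output directory are illustrative):

    import numpy as np
    from sklearn.preprocessing import FunctionTransformer
    import bob.pipelines as mario
    from bob.pipelines.utils import hash_string

    samples = [mario.Sample(np.arange(5), key=f"sample_{i}") for i in range(3)]

    transformer = mario.wrap(
        [FunctionTransformer, "sample", "checkpoint"],
        func=lambda X: X + 1,          # illustrative stateless transform
        validate=True,
        features_dir="./checkpoints",  # illustrative checkpoint directory
        hash_fn=hash_string,           # sample.key -> hashed subdirectory
    )
    features = transformer.transform(samples)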
@@ -215,6 +239,7 @@ class CheckpointWrapper(BaseWrapper, TransformerMixin):
save_func=None,
load_func=None,
sample_attribute="data",
hash_fn=None,
**kwargs,
):
super().__init__(**kwargs)
@@ -225,6 +250,7 @@ class CheckpointWrapper(BaseWrapper, TransformerMixin):
self.save_func = save_func or bob.io.base.save
self.load_func = load_func or bob.io.base.load
self.sample_attribute = sample_attribute
self.hash_fn = hash_fn
if model_path is None and features_dir is None:
logger.warning(
"Both model_path and features_dir are None. "
@@ -306,6 +332,7 @@ class CheckpointWrapper(BaseWrapper, TransformerMixin):
return self.save_model()
def make_path(self, sample):
if self.features_dir is None:
return None
@@ -315,7 +342,10 @@ class CheckpointWrapper(BaseWrapper, TransformerMixin):
"Sample.key values should be relative paths with no "
f"reference to upper folders. Got: {key}"
)
return os.path.join(self.features_dir, key + self.extension)
hash_dir_name = self.hash_fn(key) if self.hash_fn is not None else ""
return os.path.join(self.features_dir, hash_dir_name, key + self.extension)
def save(self, sample):
path = self.make_path(sample)
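
The resulting checkpoint layout, sketched with made-up values (the `.h5` extension is only an assumption about the default `extension`):

    import os
    from bob.pipelines.utils import hash_string

    features_dir = "./checkpoints"   # illustrative
    key = "subject_01/sample_03"     # Sample.key must be a relative path
    extension = ".h5"                # assumed default extension

    # Without hash_fn: ./checkpoints/subject_01/sample_03.h5
    # With hash_fn:    ./checkpoints/<8-letter bucket>/subject_01/sample_03.h5
    hash_dir_name = hash_string(key)
    path = os.path.join(features_dir, hash_dir_name, key + extension)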
@@ -376,11 +406,7 @@ class DaskWrapper(BaseWrapper, TransformerMixin):
"""
def __init__(
self,
estimator,
fit_tag=None,
transform_tag=None,
**kwargs,
self, estimator, fit_tag=None, transform_tag=None, **kwargs,
):
super().__init__(**kwargs)
self.estimator = estimator
@@ -432,10 +458,7 @@ class DaskWrapper(BaseWrapper, TransformerMixin):
# change the name to have a better name in dask graphs
_fit.__name__ = f"{_frmt(self)}.fit"
self._dask_state = delayed(_fit)(
X,
y,
)
self._dask_state = delayed(_fit)(X, y,)
if self.fit_tag is not None:
self.resource_tags[self._dask_state] = self.fit_tag
@@ -528,8 +551,9 @@ def wrap(bases, estimator=None, **kwargs):
# when checkpointing a pipeline, checkpoint each transformer in its own folder
new_kwargs = dict(kwargs)
features_dir, model_path = kwargs.get("features_dir"), kwargs.get(
"model_path"
features_dir, model_path = (
kwargs.get("features_dir"),
kwargs.get("model_path"),
)
if features_dir is not None:
new_kwargs["features_dir"] = os.path.join(features_dir, name)