Commit 73c2e695 authored by Amir MOHAMMADI

[DelayedSample] Fix issues when an attribute was set

parent 13c6d8d2
Pipeline #45973 passed with stage in 3 minutes and 46 seconds
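For context, a minimal sketch of the behaviour this commit enables, mirroring the new test added below (the top-level import path is an assumption):

    from bob.pipelines import DelayedSample

    # "annot" is resolved lazily through its delayed loader...
    sample = DelayedSample(
        lambda: 1,  # lazy loader for sample.data
        delayed_attributes=dict(annot=lambda: "annotation"),
    )
    assert sample.annot == "annotation"

    # ...and assigning it now drops the delayed loader instead of being shadowed by it
    sample.annot = "changed"
    assert sample.annot == "changed"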
@@ -2,6 +2,7 @@
 from collections.abc import MutableSequence
 from collections.abc import Sequence
+from typing import Any
 
 import h5py
 import numpy as np
@@ -120,18 +121,29 @@ class DelayedSample(_ReprMixin):
     def __init__(self, load, parent=None, delayed_attributes=None, **kwargs):
         self.delayed_attributes = delayed_attributes
+        self.__running_init__ = True
+        # create the delayed attributes but leave their values as None for now.
+        if delayed_attributes is not None:
+            kwargs.update({k: None for k in delayed_attributes})
         _copy_attributes(self, parent, kwargs)
         self.load = load
+        del self.__running_init__
 
-    def __getattribute__(self, name: str):
+    def __getattribute__(self, name: str) -> Any:
         delayed_attributes = super().__getattribute__("delayed_attributes")
         if delayed_attributes is None or name not in delayed_attributes:
             return super().__getattribute__(name)
         return delayed_attributes[name]()
 
+    def __setattr__(self, name: str, value: Any) -> None:
+        if name != "delayed_attributes" and "__running_init__" not in self.__dict__:
+            delayed_attributes = self.delayed_attributes
+            # if setting an attribute which was delayed, remove it from delayed_attributes
+            if delayed_attributes is not None and name in delayed_attributes:
+                del delayed_attributes[name]
+        super().__setattr__(name, value)
+
     @property
     def data(self):
         """Loads the data from the disk file."""
@@ -166,15 +178,12 @@ class DelayedSampleSet(SampleSet):
     """A set of samples with extra attributes"""
 
     def __init__(self, load, parent=None, **kwargs):
-        self._data = None
         self.load = load
         _copy_attributes(self, parent, kwargs)
 
     @property
     def samples(self):
-        if self._data is None:
-            self._data = self.load()
-        return self._data
+        return self.load()
 
 
 class SampleBatch(Sequence, _ReprMixin):
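With this change, DelayedSampleSet no longer caches the result of load(); every access to .samples calls the loader again. A hedged sketch of the new behaviour (top-level import path assumed):

    from bob.pipelines import DelayedSampleSet, Sample

    calls = []

    def load():
        calls.append(1)  # record each invocation so the absence of caching is visible
        return [Sample(1), Sample(2)]

    sset = DelayedSampleSet(load)
    _ = sset.samples
    _ = sset.samples
    assert len(calls) == 2  # before this commit, load() ran once and was cached in _data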
@@ -222,7 +231,7 @@ def sample_to_hdf5(sample, hdf5):
             sample_to_hdf5(s, group)
     else:
         for s in sample.__dict__:
-            hdf5[s] = sample.__dict__[s]
+            hdf5[s] = getattr(sample, s)
 
 
 def hdf5_to_sample(hdf5):
@@ -250,6 +259,6 @@ def hdf5_to_sample(hdf5):
     # If it has no groups, return a sample
     sample = Sample(None)
     for k in hdf5.keys():
-        sample.__dict__[k] = hdf5[k].value
+        setattr(sample, k, hdf5[k].value)
 
     return sample
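A hedged round-trip sketch for the two helpers above (the module import path is an assumption; note that hdf5_to_sample reads datasets through the legacy h5py .value accessor, which requires h5py older than 3.0):

    import h5py

    from bob.pipelines.sample import Sample, hdf5_to_sample, sample_to_hdf5

    sample = Sample([1, 2, 3], key="sample-1")
    with h5py.File("sample.hdf5", "w") as f:
        sample_to_hdf5(sample, f)  # writes sample.data and sample.key as datasets
    with h5py.File("sample.hdf5", "r") as f:
        restored = hdf5_to_sample(f)  # attributes are restored via setattr, as in the hunk above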
@@ -101,3 +101,6 @@ def test_delayed_samples():
     child_sample = Sample(1, parent=delayed_sample)
     assert child_sample.data == 1, child_sample.data
     assert child_sample.annot == "annotation", child_sample.annot
+
+    delayed_sample.annot = "changed"
+    assert delayed_sample.annot == "changed", delayed_sample.annot
@@ -4,7 +4,6 @@ import os
 from functools import partial
 
-import bob.io.base
 import cloudpickle
 import dask.bag
@@ -15,6 +14,8 @@ from sklearn.base import TransformerMixin
 from sklearn.pipeline import Pipeline
 from sklearn.preprocessing import FunctionTransformer
 
+import bob.io.base
+
 from .sample import DelayedSample
 from .sample import SampleBatch
 from .sample import SampleSet
@@ -137,13 +138,14 @@ class SampleWrapper(BaseWrapper, TransformerMixin):
         if isinstance(samples[0], SampleSet):
             return [
                 SampleSet(
-                    self._samples_transform(sset.samples, method_name), parent=sset,
+                    self._samples_transform(sset.samples, method_name),
+                    parent=sset,
                 )
                 for sset in samples
             ]
         else:
             kwargs = _make_kwargs_from_samples(samples, self.transform_extra_arguments)
-            delayed = DelayedSamplesCall(partial(method, **kwargs), func_name, samples,)
+            delayed = DelayedSamplesCall(partial(method, **kwargs), func_name, samples)
             if self.output_attribute != "data":
                 # Edit the sample.<output_attribute> instead of data
                 for i, s in enumerate(samples):
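The output_attribute branch above stores the transformer result on an alternative sample attribute instead of sample.data. A hedged sketch of how that is typically used (the SampleWrapper constructor arguments are inferred from the attributes referenced in this diff and may differ):

    from sklearn.preprocessing import FunctionTransformer

    from bob.pipelines import Sample
    from bob.pipelines.wrappers import SampleWrapper

    annotator = FunctionTransformer(lambda X: [x * 2 for x in X])
    wrapped = SampleWrapper(annotator, output_attribute="annotations")

    samples = [Sample(1, key="a"), Sample(2, key="b")]
    transformed = wrapped.transform(samples)
    # each transformed sample keeps its original .data and exposes the
    # transformer output on .annotations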
@@ -202,13 +204,13 @@ class CheckpointWrapper(BaseWrapper, TransformerMixin):
     estimator
         The scikit-learn estimator to be wrapped.
 
     model_path: str
         Saves the estimator state in this directory if the `estimator` is stateful
 
     features_dir: str
         Saves the transformed data in this directory
 
     extension: str
         Default extension of the transformed features
@@ -216,14 +218,14 @@
         Pointer to a customized function that saves transformed features to disk
 
     load_func
-        Pointer to a customized function that loads transformed features from disk
+        Pointer to a customized function that loads transformed features from disk
 
     sample_attribute: str
         Defines the payload attribute of the sample (Default: `data`)
 
     hash_fn
         Pointer to a hash function. This hash function maps
-        `sample.key` to a hash code and this hash code corresponds to a relative directory
+        `sample.key` to a hash code and this hash code corresponds to a relative directory
         where a single `sample` will be checkpointed.
         This is useful when it is desirable to keep file directories with fewer than
         a certain number of files.
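The hash_fn mechanism described above can be any callable mapping sample.key to a relative directory. A hedged illustration (the hashing scheme below is an example, not part of this commit):

    import hashlib
    import os

    def hash_fn(key):
        # spread checkpoints over a two-level directory tree derived from the key's digest
        digest = hashlib.sha256(key.encode("utf-8")).hexdigest()
        return os.path.join(digest[:2], digest[2:4])

    # e.g. CheckpointWrapper(estimator, features_dir="./features", hash_fn=hash_fn)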
@@ -406,7 +408,11 @@ class DaskWrapper(BaseWrapper, TransformerMixin):
     """
 
     def __init__(
-        self, estimator, fit_tag=None, transform_tag=None, **kwargs,
+        self,
+        estimator,
+        fit_tag=None,
+        transform_tag=None,
+        **kwargs,
     ):
         super().__init__(**kwargs)
         self.estimator = estimator
@@ -458,7 +464,10 @@
         # change the name to have a better name in dask graphs
         _fit.__name__ = f"{_frmt(self)}.fit"
 
-        self._dask_state = delayed(_fit)(X, y,)
+        self._dask_state = delayed(_fit)(
+            X,
+            y,
+        )
 
         if self.fit_tag is not None:
             self.resource_tags[self._dask_state] = self.fit_tag