Commit 22b36d8c authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira

Merge branch 'fix-delayed-samples-setattr' into 'master'

[DelayedSample(Set)] make load and delayed_attributes private

See merge request !50
parents e526e5f5 b46f6a39
Pipeline #45995 passed with stages
in 9 minutes and 23 seconds
......@@ -9,20 +9,20 @@ import numpy as np
from bob.io.base import vstack_features
SAMPLE_DATA_ATTRS = ("data", "load", "samples", "delayed_attributes")
SAMPLE_DATA_ATTRS = ("data", "samples")
def _copy_attributes(sample, parent, kwargs):
"""Copies attributes from a dictionary to self."""
if parent is not None:
for key in parent.__dict__:
if key in SAMPLE_DATA_ATTRS:
if key.startswith("_") or key in SAMPLE_DATA_ATTRS:
continue
setattr(sample, key, getattr(parent, key))
for key, value in kwargs.items():
if key in SAMPLE_DATA_ATTRS:
if key.startswith("_") or key in SAMPLE_DATA_ATTRS:
continue
setattr(sample, key, value)
......@@ -33,9 +33,7 @@ class _ReprMixin:
return (
f"{self.__class__.__name__}("
+ ", ".join(
f"{k}={v!r}"
for k, v in self.__dict__.items()
if k != "delayed_attributes"
f"{k}={v!r}" for k, v in self.__dict__.items() if not k.startswith("_")
)
+ ")"
)
......@@ -120,24 +118,27 @@ class DelayedSample(_ReprMixin):
"""
def __init__(self, load, parent=None, delayed_attributes=None, **kwargs):
self.delayed_attributes = delayed_attributes
self.__running_init__ = True
self._delayed_attributes = delayed_attributes
# create the delayed attributes but leave the their values as None for now.
if delayed_attributes is not None:
kwargs.update({k: None for k in delayed_attributes})
_copy_attributes(self, parent, kwargs)
self.load = load
self._load = load
del self.__running_init__
def __getattribute__(self, name: str) -> Any:
delayed_attributes = super().__getattribute__("delayed_attributes")
try:
delayed_attributes = super().__getattribute__("_delayed_attributes")
except AttributeError:
delayed_attributes = None
if delayed_attributes is None or name not in delayed_attributes:
return super().__getattribute__(name)
return delayed_attributes[name]()
def __setattr__(self, name: str, value: Any) -> None:
if name != "delayed_attributes" and "__running_init__" not in self.__dict__:
delayed_attributes = self.delayed_attributes
delayed_attributes = getattr(self, "_delayed_attributes", None)
# if setting an attribute which was delayed, remove it from delayed_attributes
if delayed_attributes is not None and name in delayed_attributes:
del delayed_attributes[name]
......@@ -147,7 +148,7 @@ class DelayedSample(_ReprMixin):
@property
def data(self):
"""Loads the data from the disk file."""
return self.load()
return self._load()
class SampleSet(MutableSequence, _ReprMixin):
......@@ -178,12 +179,12 @@ class DelayedSampleSet(SampleSet):
"""A set of samples with extra attributes"""
def __init__(self, load, parent=None, **kwargs):
self.load = load
self._load = load
_copy_attributes(self, parent, kwargs)
@property
def samples(self):
return self.load()
return self._load()
class SampleBatch(Sequence, _ReprMixin):
......
......@@ -101,6 +101,10 @@ def test_delayed_samples():
child_sample = Sample(1, parent=delayed_sample)
assert child_sample.data == 1, child_sample.data
assert child_sample.annot == "annotation", child_sample.annot
assert child_sample.__dict__ == {
"data": 1,
"annot": "annotation",
}, child_sample.__dict__
delayed_sample.annot = "changed"
assert delayed_sample.annot == "changed", delayed_sample.annot
......@@ -11,9 +11,10 @@ from sklearn.preprocessing import FunctionTransformer
from sklearn.utils.estimator_checks import check_estimator
from sklearn.utils.validation import check_array
from sklearn.utils.validation import check_is_fitted
from bob.pipelines.utils import hash_string
import bob.pipelines as mario
import tempfile
from bob.pipelines.utils import hash_string
def _offset_add_func(X, offset=1):
......@@ -189,7 +190,7 @@ def test_checkpoint_function_sample_transfomer():
features = transformer.transform(samples)
# Checking if we have 8 chars in the second level
assert len(features[0].load.args[0].split("/")[-2]) == 8
assert len(features[0]._load.args[0].split("/")[-2]) == 8
_assert_all_close_numpy_array(oracle, [s.data for s in features])
......
......@@ -39,9 +39,9 @@ def _load_fn_to_xarray(load_fn, meta=None):
def _one_sample_to_dataset(sample, meta=None):
dataset = {}
delayed_attributes = getattr(sample, "delayed_attributes", None) or {}
delayed_attributes = getattr(sample, "_delayed_attributes", None) or {}
for k in sample.__dict__:
if k in SAMPLE_DATA_ATTRS or k in delayed_attributes:
if k in SAMPLE_DATA_ATTRS or k in delayed_attributes or k.startswith("_"):
continue
dataset[k] = getattr(sample, k)
......
......@@ -157,9 +157,9 @@ Below, follow an example on how to use :any:`DelayedSample`.
... return np.zeros((2,))
>>> delayed_sample = mario.DelayedSample(load, metadata=1)
>>> delayed_sample
DelayedSample(metadata=1, load=<function load at ...)
DelayedSample(metadata=1)
As soon as you access the ``.data`` attribute, the data is loaded and kept in memory:
As soon as you access the ``.data`` attribute, the data is loaded and returned:
.. doctest::
......@@ -213,6 +213,6 @@ transform each sample inside and returns the same SampleSets with new data.
>>> transformed_sample_sets = sample_pipeline.transform(sample_sets)
>>> transformed_sample_sets[0].samples[1]
DelayedSample(offset=array([1]), load=...)
DelayedSample(offset=array([1]))
>>> transformed_sample_sets[0].samples[1].data
array([1., 1.])
......@@ -76,7 +76,7 @@ to convert our dataset to a list of samples first:
... for i, y in enumerate(iris.target)
... ]
>>> samples[0]
DelayedSample(target=0, load=...)
DelayedSample(target=0)
You may be already familiar with our sample concept. If not, please read more on
:ref:`bob.pipelines.sample`. Now, to optimize our process, we will represent our
......@@ -265,7 +265,7 @@ features. Let's add the ``key`` metadata to our dataset first:
... for i, y in enumerate(iris.target)
... ]
>>> samples[0]
DelayedSample(target=0, key=0, load=...)
DelayedSample(target=0, key=0)
>>> # construct the meta from one sample
>>> meta = xr.DataArray(samples[0].data, dims=("feature"))
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment