Commit 98832088 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

[DelayedSample] Handle attribute copy of delayed attributes

parent e0901ee2
Pipeline #45910 passed with stage
in 7 minutes and 47 seconds
......@@ -8,19 +8,34 @@ import numpy as np
from bob.io.base import vstack_features
SAMPLE_DATA_ATTRS = ("data", "load", "samples", "_data")
SAMPLE_DATA_ATTRS = ("data", "load", "samples", "delayed_attributes")
def _copy_attributes(s, d):
def _copy_attributes(sample, parent, kwargs):
"""Copies attributes from a dictionary to self."""
s.__dict__.update(dict((k, v) for k, v in d.items() if k not in SAMPLE_DATA_ATTRS))
if parent is not None:
for key in parent.__dict__:
if key in SAMPLE_DATA_ATTRS:
continue
setattr(sample, key, getattr(parent, key))
for key, value in kwargs.items():
if key in SAMPLE_DATA_ATTRS:
continue
setattr(sample, key, value)
class _ReprMixin:
def __repr__(self):
return (
f"{self.__class__.__name__}("
+ ", ".join(f"{k}={v!r}" for k, v in self.__dict__.items())
+ ", ".join(
f"{k}={v!r}"
for k, v in self.__dict__.items()
if k != "delayed_attributes"
)
+ ")"
)
......@@ -72,9 +87,7 @@ class Sample(_ReprMixin):
def __init__(self, data, parent=None, **kwargs):
self.data = data
if parent is not None:
_copy_attributes(self, parent.__dict__)
_copy_attributes(self, kwargs)
_copy_attributes(self, parent, kwargs)
class DelayedSample(_ReprMixin):
......@@ -106,22 +119,18 @@ class DelayedSample(_ReprMixin):
"""
def __init__(self, load, parent=None, delayed_attributes=None, **kwargs):
if parent is not None:
_copy_attributes(self, parent.__dict__)
_copy_attributes(self, kwargs)
self.load = load
self.delayed_attributes = delayed_attributes
# create the delayed attributes but leave their values as None for now.
if delayed_attributes is not None:
kwargs.update({k: None for k in delayed_attributes})
_copy_attributes(self, parent, kwargs)
self.load = load
def __getattr__(self, name: str):
if self.delayed_attributes is None:
raise AttributeError(name)
load_fn = self.delayed_attributes.get(name)
if load_fn is None:
raise AttributeError(name)
return load_fn()
def __getattribute__(self, name: str):
    """Return attribute ``name``, invoking its delayed loader when one exists.

    If ``delayed_attributes`` is a mapping that contains ``name``, the
    callable stored under that key is called and its result is returned.
    Otherwise the regular attribute lookup is used.

    Raises
    ------
    AttributeError
        If ``name`` is neither a delayed attribute nor a regular attribute.
    """
    try:
        delayed_attributes = super().__getattribute__("delayed_attributes")
    except AttributeError:
        # ``delayed_attributes`` is not set yet: the instance was created
        # without running ``__init__`` (e.g. ``copy``/``pickle`` rebuild
        # objects through ``__new__`` and then read ``__dict__``). Fall back
        # to the normal lookup so those accesses do not fail with an
        # AttributeError about an unrelated name.
        delayed_attributes = None
    if delayed_attributes is None or name not in delayed_attributes:
        return super().__getattribute__(name)
    return delayed_attributes[name]()
@property
def data(self):
......@@ -134,9 +143,7 @@ class SampleSet(MutableSequence, _ReprMixin):
def __init__(self, samples, parent=None, **kwargs):
self.samples = samples
if parent is not None:
_copy_attributes(self, parent.__dict__)
_copy_attributes(self, kwargs)
_copy_attributes(self, parent, kwargs)
def __len__(self):
return len(self.samples)
......@@ -161,9 +168,7 @@ class DelayedSampleSet(SampleSet):
def __init__(self, load, parent=None, **kwargs):
self._data = None
self.load = load
if parent is not None:
_copy_attributes(self, parent.__dict__)
_copy_attributes(self, kwargs)
_copy_attributes(self, parent, kwargs)
@property
def samples(self):
......
......@@ -7,6 +7,7 @@ import tempfile
import h5py
import numpy as np
from bob.pipelines import DelayedSample
from bob.pipelines import DelayedSampleSet
from bob.pipelines import Sample
from bob.pipelines import SampleSet
......@@ -84,3 +85,19 @@ def test_sample_hdf5():
compare = [a == b for a, b in zip(samples_deserialized, samples)]
assert np.sum(compare) == 10
def test_delayed_samples():
    """Check delayed attributes on a DelayedSample and on a child Sample.

    A ``DelayedSample`` is built with a data loader and one delayed
    attribute (``annot``); both must return what their loaders return.
    A ``Sample`` created with that delayed sample as ``parent`` must carry
    its own data and still expose ``annot``.
    """

    def _zero():
        return 0

    def _annotation():
        return "annotation"

    lazy = DelayedSample(_zero, delayed_attributes={"annot": _annotation})

    lazy_data = lazy.data
    assert lazy_data == 0, lazy_data
    lazy_annot = lazy.annot
    assert lazy_annot == "annotation", lazy_annot

    # Attributes of the parent must be visible on the derived sample.
    child = Sample(1, parent=lazy)

    child_data = child.data
    assert child_data == 1, child_data
    child_annot = child.annot
    assert child_annot == "annotation", child_annot
......@@ -157,7 +157,7 @@ Below, follow an example on how to use :any:`DelayedSample`.
... return np.zeros((2,))
>>> delayed_sample = mario.DelayedSample(load, metadata=1)
>>> delayed_sample
DelayedSample(load=<function load at ...>, metadata=1, _data=None)
DelayedSample(metadata=1, load=<function load at ...)
As soon as you access the ``.data`` attribute, the data is loaded and kept in memory:
......@@ -213,6 +213,6 @@ transform each sample inside and returns the same SampleSets with new data.
>>> transformed_sample_sets = sample_pipeline.transform(sample_sets)
>>> transformed_sample_sets[0].samples[1]
DelayedSample(load=..., offset=array([1]), _data=None)
DelayedSample(offset=array([1]), load=...)
>>> transformed_sample_sets[0].samples[1].data
array([1., 1.])
......@@ -76,7 +76,7 @@ to convert our dataset to a list of samples first:
... for i, y in enumerate(iris.target)
... ]
>>> samples[0]
DelayedSample(load=functools.partial(<function load at ...>, 0), target=0, _data=None)
DelayedSample(target=0, load=...)
You may be already familiar with our sample concept. If not, please read more on
:ref:`bob.pipelines.sample`. Now, to optimize our process, we will represent our
......@@ -265,7 +265,7 @@ features. Let's add the ``key`` metadata to our dataset first:
... for i, y in enumerate(iris.target)
... ]
>>> samples[0]
DelayedSample(load=functools.partial(<function load at ...>, 0), target=0, key=0, _data=None)
DelayedSample(target=0, key=0, load=...)
>>> # construct the meta from one sample
>>> meta = xr.DataArray(samples[0].data, dims=("feature"))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment