Skip to content
Snippets Groups Projects
Commit d53f32a7 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

[DelayedSample(Set)] make load and delayed_attributes private

parent e526e5f5
No related branches found
No related tags found
No related merge requests found
Pipeline #45992 failed
......@@ -9,20 +9,20 @@ import numpy as np
from bob.io.base import vstack_features
SAMPLE_DATA_ATTRS = ("data", "load", "samples", "delayed_attributes")
SAMPLE_DATA_ATTRS = ("data", "samples")
def _copy_attributes(sample, parent, kwargs):
"""Copies attributes from a dictionary to self."""
if parent is not None:
for key in parent.__dict__:
if key in SAMPLE_DATA_ATTRS:
if key.startswith("_") or key in SAMPLE_DATA_ATTRS:
continue
setattr(sample, key, getattr(parent, key))
for key, value in kwargs.items():
if key in SAMPLE_DATA_ATTRS:
if key.startswith("_") or key in SAMPLE_DATA_ATTRS:
continue
setattr(sample, key, value)
......@@ -120,24 +120,27 @@ class DelayedSample(_ReprMixin):
"""
def __init__(self, load, parent=None, delayed_attributes=None, **kwargs):
self.delayed_attributes = delayed_attributes
self.__running_init__ = True
self._delayed_attributes = delayed_attributes
# create the delayed attributes but leave the their values as None for now.
if delayed_attributes is not None:
kwargs.update({k: None for k in delayed_attributes})
_copy_attributes(self, parent, kwargs)
self.load = load
self._load = load
del self.__running_init__
def __getattribute__(self, name: str) -> Any:
delayed_attributes = super().__getattribute__("delayed_attributes")
try:
delayed_attributes = super().__getattribute__("_delayed_attributes")
except AttributeError:
delayed_attributes = None
if delayed_attributes is None or name not in delayed_attributes:
return super().__getattribute__(name)
return delayed_attributes[name]()
def __setattr__(self, name: str, value: Any) -> None:
if name != "delayed_attributes" and "__running_init__" not in self.__dict__:
delayed_attributes = self.delayed_attributes
delayed_attributes = getattr(self, "_delayed_attributes", None)
# if setting an attribute which was delayed, remove it from delayed_attributes
if delayed_attributes is not None and name in delayed_attributes:
del delayed_attributes[name]
......@@ -147,7 +150,7 @@ class DelayedSample(_ReprMixin):
@property
def data(self):
"""Loads the data from the disk file."""
return self.load()
return self._load()
class SampleSet(MutableSequence, _ReprMixin):
......@@ -178,12 +181,12 @@ class DelayedSampleSet(SampleSet):
"""A set of samples with extra attributes"""
def __init__(self, load, parent=None, **kwargs):
self.load = load
self._load = load
_copy_attributes(self, parent, kwargs)
@property
def samples(self):
return self.load()
return self._load()
class SampleBatch(Sequence, _ReprMixin):
......
......@@ -101,6 +101,10 @@ def test_delayed_samples():
child_sample = Sample(1, parent=delayed_sample)
assert child_sample.data == 1, child_sample.data
assert child_sample.annot == "annotation", child_sample.annot
assert child_sample.__dict__ == {
"data": 1,
"annot": "annotation",
}, child_sample.__dict__
delayed_sample.annot = "changed"
assert delayed_sample.annot == "changed", delayed_sample.annot
......@@ -11,9 +11,10 @@ from sklearn.preprocessing import FunctionTransformer
from sklearn.utils.estimator_checks import check_estimator
from sklearn.utils.validation import check_array
from sklearn.utils.validation import check_is_fitted
from bob.pipelines.utils import hash_string
import bob.pipelines as mario
import tempfile
from bob.pipelines.utils import hash_string
def _offset_add_func(X, offset=1):
......@@ -189,7 +190,7 @@ def test_checkpoint_function_sample_transfomer():
features = transformer.transform(samples)
# Checking if we have 8 chars in the second level
assert len(features[0].load.args[0].split("/")[-2]) == 8
assert len(features[0]._load.args[0].split("/")[-2]) == 8
_assert_all_close_numpy_array(oracle, [s.data for s in features])
......
......@@ -39,9 +39,9 @@ def _load_fn_to_xarray(load_fn, meta=None):
def _one_sample_to_dataset(sample, meta=None):
dataset = {}
delayed_attributes = getattr(sample, "delayed_attributes", None) or {}
delayed_attributes = getattr(sample, "_delayed_attributes", None) or {}
for k in sample.__dict__:
if k in SAMPLE_DATA_ATTRS or k in delayed_attributes:
if k in SAMPLE_DATA_ATTRS or k in delayed_attributes or k.startswith("_"):
continue
dataset[k] = getattr(sample, k)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment