Skip to content
Snippets Groups Projects
Commit 2dc85788 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

[DelayedSample(Set)] make load and delayed_attributes private

parent e526e5f5
No related branches found
No related tags found
No related merge requests found
Pipeline #45990 failed
...@@ -9,20 +9,20 @@ import numpy as np ...@@ -9,20 +9,20 @@ import numpy as np
from bob.io.base import vstack_features from bob.io.base import vstack_features
SAMPLE_DATA_ATTRS = ("data", "load", "samples", "delayed_attributes") SAMPLE_DATA_ATTRS = ("data", "samples")
def _copy_attributes(sample, parent, kwargs): def _copy_attributes(sample, parent, kwargs):
"""Copies attributes from a dictionary to self.""" """Copies attributes from a dictionary to self."""
if parent is not None: if parent is not None:
for key in parent.__dict__: for key in parent.__dict__:
if key in SAMPLE_DATA_ATTRS: if key.startswith("__") or key in SAMPLE_DATA_ATTRS:
continue continue
setattr(sample, key, getattr(parent, key)) setattr(sample, key, getattr(parent, key))
for key, value in kwargs.items(): for key, value in kwargs.items():
if key in SAMPLE_DATA_ATTRS: if key.startswith("__") or key in SAMPLE_DATA_ATTRS:
continue continue
setattr(sample, key, value) setattr(sample, key, value)
...@@ -120,24 +120,27 @@ class DelayedSample(_ReprMixin): ...@@ -120,24 +120,27 @@ class DelayedSample(_ReprMixin):
""" """
def __init__(self, load, parent=None, delayed_attributes=None, **kwargs): def __init__(self, load, parent=None, delayed_attributes=None, **kwargs):
self.delayed_attributes = delayed_attributes
self.__running_init__ = True self.__running_init__ = True
self._delayed_attributes = delayed_attributes
# create the delayed attributes but leave the their values as None for now. # create the delayed attributes but leave the their values as None for now.
if delayed_attributes is not None: if delayed_attributes is not None:
kwargs.update({k: None for k in delayed_attributes}) kwargs.update({k: None for k in delayed_attributes})
_copy_attributes(self, parent, kwargs) _copy_attributes(self, parent, kwargs)
self.load = load self._load = load
del self.__running_init__ del self.__running_init__
def __getattribute__(self, name: str) -> Any: def __getattribute__(self, name: str) -> Any:
delayed_attributes = super().__getattribute__("delayed_attributes") try:
delayed_attributes = super().__getattribute__("_delayed_attributes")
except AttributeError:
delayed_attributes = None
if delayed_attributes is None or name not in delayed_attributes: if delayed_attributes is None or name not in delayed_attributes:
return super().__getattribute__(name) return super().__getattribute__(name)
return delayed_attributes[name]() return delayed_attributes[name]()
def __setattr__(self, name: str, value: Any) -> None: def __setattr__(self, name: str, value: Any) -> None:
if name != "delayed_attributes" and "__running_init__" not in self.__dict__: if name != "delayed_attributes" and "__running_init__" not in self.__dict__:
delayed_attributes = self.delayed_attributes delayed_attributes = getattr(self, "_delayed_attributes", None)
# if setting an attribute which was delayed, remove it from delayed_attributes # if setting an attribute which was delayed, remove it from delayed_attributes
if delayed_attributes is not None and name in delayed_attributes: if delayed_attributes is not None and name in delayed_attributes:
del delayed_attributes[name] del delayed_attributes[name]
...@@ -147,7 +150,7 @@ class DelayedSample(_ReprMixin): ...@@ -147,7 +150,7 @@ class DelayedSample(_ReprMixin):
@property @property
def data(self): def data(self):
"""Loads the data from the disk file.""" """Loads the data from the disk file."""
return self.load() return self._load()
class SampleSet(MutableSequence, _ReprMixin): class SampleSet(MutableSequence, _ReprMixin):
...@@ -178,12 +181,12 @@ class DelayedSampleSet(SampleSet): ...@@ -178,12 +181,12 @@ class DelayedSampleSet(SampleSet):
"""A set of samples with extra attributes""" """A set of samples with extra attributes"""
def __init__(self, load, parent=None, **kwargs): def __init__(self, load, parent=None, **kwargs):
self.load = load self._load = load
_copy_attributes(self, parent, kwargs) _copy_attributes(self, parent, kwargs)
@property @property
def samples(self): def samples(self):
return self.load() return self._load()
class SampleBatch(Sequence, _ReprMixin): class SampleBatch(Sequence, _ReprMixin):
......
...@@ -11,9 +11,10 @@ from sklearn.preprocessing import FunctionTransformer ...@@ -11,9 +11,10 @@ from sklearn.preprocessing import FunctionTransformer
from sklearn.utils.estimator_checks import check_estimator from sklearn.utils.estimator_checks import check_estimator
from sklearn.utils.validation import check_array from sklearn.utils.validation import check_array
from sklearn.utils.validation import check_is_fitted from sklearn.utils.validation import check_is_fitted
from bob.pipelines.utils import hash_string
import bob.pipelines as mario import bob.pipelines as mario
import tempfile
from bob.pipelines.utils import hash_string
def _offset_add_func(X, offset=1): def _offset_add_func(X, offset=1):
...@@ -189,7 +190,7 @@ def test_checkpoint_function_sample_transfomer(): ...@@ -189,7 +190,7 @@ def test_checkpoint_function_sample_transfomer():
features = transformer.transform(samples) features = transformer.transform(samples)
# Checking if we have 8 chars in the second level # Checking if we have 8 chars in the second level
assert len(features[0].load.args[0].split("/")[-2]) == 8 assert len(features[0]._load.args[0].split("/")[-2]) == 8
_assert_all_close_numpy_array(oracle, [s.data for s in features]) _assert_all_close_numpy_array(oracle, [s.data for s in features])
......
...@@ -39,9 +39,9 @@ def _load_fn_to_xarray(load_fn, meta=None): ...@@ -39,9 +39,9 @@ def _load_fn_to_xarray(load_fn, meta=None):
def _one_sample_to_dataset(sample, meta=None): def _one_sample_to_dataset(sample, meta=None):
dataset = {} dataset = {}
delayed_attributes = getattr(sample, "delayed_attributes", None) or {} delayed_attributes = getattr(sample, "_delayed_attributes", None) or {}
for k in sample.__dict__: for k in sample.__dict__:
if k in SAMPLE_DATA_ATTRS or k in delayed_attributes: if k in SAMPLE_DATA_ATTRS or k in delayed_attributes or k.startswith("_"):
continue continue
dataset[k] = getattr(sample, k) dataset[k] = getattr(sample, k)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment