Commit 4a3aac01 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

[Sample] Created a cached version of DelayedSampleSet

Fixed test case
parent 607bee73
Pipeline #46115 passed with stage
in 3 minutes and 54 seconds
......@@ -3,7 +3,7 @@ from . import transformers # noqa
from . import utils # noqa
from . import xarray as xr # noqa
from .sample import DelayedSample
from .sample import DelayedSampleSet
from .sample import DelayedSampleSet, DelayedSampleSetCached
from .sample import Sample
from .sample import SampleSet
from .sample import SampleBatch
......@@ -187,6 +187,21 @@ class DelayedSampleSet(SampleSet):
return self._load()
class DelayedSampleSetCached(DelayedSampleSet):
"""A cached version of DelayedSampleSet"""
def __init__(self, load, parent=None, **kwargs):
super().__init__(load, parent=parent, kwargs=kwargs)
self._data = None
_copy_attributes(self, parent, kwargs)
def samples(self):
if self._data is None:
self._data = self._load()
return self._data
class SampleBatch(Sequence, _ReprMixin):
"""A batch of samples that looks like [ for s in samples]
......@@ -8,7 +8,7 @@ import h5py
import numpy as np
from bob.pipelines import DelayedSample
from bob.pipelines import DelayedSampleSet
from bob.pipelines import DelayedSampleSet, DelayedSampleSetCached
from bob.pipelines import Sample
from bob.pipelines import SampleSet
from bob.pipelines import hdf5_to_sample
......@@ -56,6 +56,19 @@ def test_sampleset_collection():
assert len(sampleset) == n_samples
assert sampleset.samples == samples
# Testing delayed sampleset cached
with tempfile.TemporaryDirectory() as dir_name:
samples = [Sample(data, key=str(i)) for i, data in enumerate(X)]
filename = os.path.join(dir_name, "samples.pkl")
with open(filename, "wb") as f:
sampleset = DelayedSampleSetCached(functools.partial(_load, filename), key=1)
assert len(sampleset) == n_samples
assert sampleset.samples == samples
def test_sample_hdf5():
n_samples = 10
......@@ -189,8 +189,9 @@ def test_checkpoint_function_sample_transfomer():
features = transformer.transform(samples)
# Checking if we have 8 chars in the second level
assert len(features[0]._load.args[0].split("/")[-2]) == 8
# Checking if we can cast the has as integer
assert isinstance(int(features[0]._load.args[0].split("/")[-2]), int)
_assert_all_close_numpy_array(oracle, [ for s in features])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment