Skip to content
Snippets Groups Projects
Commit 16675af9 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

Merge branch 'updates' into 'master'

Updates

See merge request !53
parents 99fa97c0 4a3aac01
No related branches found
No related tags found
1 merge request!53Updates
Pipeline #46118 passed
...@@ -3,7 +3,7 @@ from . import transformers # noqa ...@@ -3,7 +3,7 @@ from . import transformers # noqa
from . import utils # noqa from . import utils # noqa
from . import xarray as xr # noqa from . import xarray as xr # noqa
from .sample import DelayedSample from .sample import DelayedSample
from .sample import DelayedSampleSet from .sample import DelayedSampleSet, DelayedSampleSetCached
from .sample import Sample from .sample import Sample
from .sample import SampleSet from .sample import SampleSet
from .sample import SampleBatch from .sample import SampleBatch
......
...@@ -187,6 +187,21 @@ class DelayedSampleSet(SampleSet): ...@@ -187,6 +187,21 @@ class DelayedSampleSet(SampleSet):
return self._load() return self._load()
class DelayedSampleSetCached(DelayedSampleSet):
"""A cached version of DelayedSampleSet"""
def __init__(self, load, parent=None, **kwargs):
super().__init__(load, parent=parent, kwargs=kwargs)
self._data = None
_copy_attributes(self, parent, kwargs)
@property
def samples(self):
if self._data is None:
self._data = self._load()
return self._data
class SampleBatch(Sequence, _ReprMixin): class SampleBatch(Sequence, _ReprMixin):
"""A batch of samples that looks like [s.data for s in samples] """A batch of samples that looks like [s.data for s in samples]
......
...@@ -8,7 +8,7 @@ import h5py ...@@ -8,7 +8,7 @@ import h5py
import numpy as np import numpy as np
from bob.pipelines import DelayedSample from bob.pipelines import DelayedSample
from bob.pipelines import DelayedSampleSet from bob.pipelines import DelayedSampleSet, DelayedSampleSetCached
from bob.pipelines import Sample from bob.pipelines import Sample
from bob.pipelines import SampleSet from bob.pipelines import SampleSet
from bob.pipelines import hdf5_to_sample from bob.pipelines import hdf5_to_sample
...@@ -56,6 +56,19 @@ def test_sampleset_collection(): ...@@ -56,6 +56,19 @@ def test_sampleset_collection():
assert len(sampleset) == n_samples assert len(sampleset) == n_samples
assert sampleset.samples == samples assert sampleset.samples == samples
# Testing delayed sampleset cached
with tempfile.TemporaryDirectory() as dir_name:
samples = [Sample(data, key=str(i)) for i, data in enumerate(X)]
filename = os.path.join(dir_name, "samples.pkl")
with open(filename, "wb") as f:
f.write(pickle.dumps(samples))
sampleset = DelayedSampleSetCached(functools.partial(_load, filename), key=1)
assert len(sampleset) == n_samples
assert sampleset.samples == samples
def test_sample_hdf5(): def test_sample_hdf5():
n_samples = 10 n_samples = 10
......
...@@ -189,8 +189,9 @@ def test_checkpoint_function_sample_transfomer(): ...@@ -189,8 +189,9 @@ def test_checkpoint_function_sample_transfomer():
) )
features = transformer.transform(samples) features = transformer.transform(samples)
# Checking if we have 8 chars in the second level # Checking if we can cast the has as integer
assert len(features[0]._load.args[0].split("/")[-2]) == 8 assert isinstance(int(features[0]._load.args[0].split("/")[-2]), int)
_assert_all_close_numpy_array(oracle, [s.data for s in features]) _assert_all_close_numpy_array(oracle, [s.data for s in features])
......
...@@ -2,8 +2,6 @@ import pickle ...@@ -2,8 +2,6 @@ import pickle
import nose import nose
import numpy as np import numpy as np
import random
import string
def is_picklable(obj): def is_picklable(obj):
...@@ -80,9 +78,10 @@ def isinstance_nested(instance, attribute, isinstance_of): ...@@ -80,9 +78,10 @@ def isinstance_nested(instance, attribute, isinstance_of):
return isinstance_nested(getattr(instance, attribute), attribute, isinstance_of) return isinstance_nested(getattr(instance, attribute), attribute, isinstance_of)
def hash_string(key, bucket_size=1000, word_length=8): def hash_string(key, bucket_size=1000):
""" """
Generates a hash code given a string. Generates a hash code given a string.
The have is given by the `sum(ord([string])) mod bucket_size`
Parameters Parameters
---------- ----------
...@@ -93,18 +92,5 @@ def hash_string(key, bucket_size=1000, word_length=8): ...@@ -93,18 +92,5 @@ def hash_string(key, bucket_size=1000, word_length=8):
bucket_size: int bucket_size: int
Size of the hash table. Size of the hash table.
word_lenth: str
Size of the output string
""" """
letters = string.ascii_lowercase return str(sum([ord(i) for i in (key)]) % bucket_size)
# Getting an integer value from the key
# and mod `n_slots` to have values between 0 and 1000
string_seed = sum([ord(i) for i in (key)]) % bucket_size
# Defining the seed so we have predictable values
random.seed(string_seed)
return "".join(random.choice(letters) for i in range(word_length))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment