Commit 16675af9 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

Merge branch 'updates' into 'master'


See merge request !53
parents 99fa97c0 4a3aac01
Pipeline #46118 passed with stages
in 8 minutes and 38 seconds
......@@ -3,7 +3,7 @@ from . import transformers # noqa
from . import utils # noqa
from . import xarray as xr # noqa
from .sample import DelayedSample
from .sample import DelayedSampleSet
from .sample import DelayedSampleSet, DelayedSampleSetCached
from .sample import Sample
from .sample import SampleSet
from .sample import SampleBatch
......@@ -187,6 +187,21 @@ class DelayedSampleSet(SampleSet):
return self._load()
class DelayedSampleSetCached(DelayedSampleSet):
"""A cached version of DelayedSampleSet"""
def __init__(self, load, parent=None, **kwargs):
super().__init__(load, parent=parent, kwargs=kwargs)
self._data = None
_copy_attributes(self, parent, kwargs)
def samples(self):
if self._data is None:
self._data = self._load()
return self._data
class SampleBatch(Sequence, _ReprMixin):
"""A batch of samples that looks like [ for s in samples]
......@@ -8,7 +8,7 @@ import h5py
import numpy as np
from bob.pipelines import DelayedSample
from bob.pipelines import DelayedSampleSet
from bob.pipelines import DelayedSampleSet, DelayedSampleSetCached
from bob.pipelines import Sample
from bob.pipelines import SampleSet
from bob.pipelines import hdf5_to_sample
......@@ -56,6 +56,19 @@ def test_sampleset_collection():
assert len(sampleset) == n_samples
assert sampleset.samples == samples
# Testing delayed sampleset cached
with tempfile.TemporaryDirectory() as dir_name:
samples = [Sample(data, key=str(i)) for i, data in enumerate(X)]
filename = os.path.join(dir_name, "samples.pkl")
with open(filename, "wb") as f:
sampleset = DelayedSampleSetCached(functools.partial(_load, filename), key=1)
assert len(sampleset) == n_samples
assert sampleset.samples == samples
def test_sample_hdf5():
n_samples = 10
......@@ -189,8 +189,9 @@ def test_checkpoint_function_sample_transfomer():
features = transformer.transform(samples)
# Checking if we have 8 chars in the second level
assert len(features[0]._load.args[0].split("/")[-2]) == 8
# Checking if we can cast the has as integer
assert isinstance(int(features[0]._load.args[0].split("/")[-2]), int)
_assert_all_close_numpy_array(oracle, [ for s in features])
......@@ -2,8 +2,6 @@ import pickle
import nose
import numpy as np
import random
import string
def is_picklable(obj):
......@@ -80,9 +78,10 @@ def isinstance_nested(instance, attribute, isinstance_of):
return isinstance_nested(getattr(instance, attribute), attribute, isinstance_of)
def hash_string(key, bucket_size=1000, word_length=8):
def hash_string(key, bucket_size=1000):
Generates a hash code given a string.
The have is given by the `sum(ord([string])) mod bucket_size`
......@@ -93,18 +92,5 @@ def hash_string(key, bucket_size=1000, word_length=8):
bucket_size: int
Size of the hash table.
word_lenth: str
Size of the output string
letters = string.ascii_lowercase
# Getting an integer value from the key
# and mod `n_slots` to have values between 0 and 1000
string_seed = sum([ord(i) for i in (key)]) % bucket_size
# Defining the seed so we have predictable values
return "".join(random.choice(letters) for i in range(word_length))
return str(sum([ord(i) for i in (key)]) % bucket_size)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment