Skip to content
Snippets Groups Projects
Commit 37448bcc authored by Amir MOHAMMADI
Browse files

[sample] Add DelayedSampleSet

parent 78eec2f9
Branches
Tags
1 merge request: !35 "Add DelayedSampleSet, remove bob.pipelines script, add pre-commit"
from . import utils
from .sample import Sample, DelayedSample, SampleSet, sample_to_hdf5, hdf5_to_sample
from .sample import Sample, DelayedSample, SampleSet, DelayedSampleSet, sample_to_hdf5, hdf5_to_sample
from .wrappers import (
BaseWrapper,
DelayedSamplesCall,
......@@ -8,7 +8,7 @@ from .wrappers import (
DaskWrapper,
ToDaskBag,
wrap,
dask_tags,
dask_tags,
)
from . import distributed
from . import transformers
......
......@@ -3,7 +3,6 @@
from collections.abc import MutableSequence, Sequence
from .utils import vstack_features
import numpy as np
import os
import h5py
SAMPLE_DATA_ATTRS = ("data", "load", "samples", "_data")
......@@ -11,13 +10,7 @@ SAMPLE_DATA_ATTRS = ("data", "load", "samples", "_data")
def _copy_attributes(s, d):
    """Copy the entries of the mapping ``d`` onto the object ``s``.

    Every key of ``d`` becomes an attribute of ``s`` with the matching value,
    except the keys listed in ``SAMPLE_DATA_ATTRS``, which are skipped.

    Parameters
    ----------
    s : object
        Destination object; it must have a ``__dict__``.
    d : dict
        Source mapping of attribute names to values.
    """
    # Written straight into ``__dict__``, so properties and ``__setattr__``
    # hooks defined on the class of ``s`` are not triggered.
    s.__dict__.update({k: v for k, v in d.items() if k not in SAMPLE_DATA_ATTRS})
class _ReprMixin:
......@@ -128,32 +121,40 @@ class SampleSet(MutableSequence, _ReprMixin):
_copy_attributes(self, parent.__dict__)
_copy_attributes(self, kwargs)
def _load(self):
    """Replace a delayed ``samples`` attribute by the data it resolves to.

    Nothing happens unless ``self.samples`` is a ``DelayedSample``.
    """
    if not isinstance(self.samples, DelayedSample):
        return
    self.samples = self.samples.data
def __len__(self):
    """Return the number of samples, resolving delayed samples first."""
    self._load()
    count = len(self.samples)
    return count
def __getitem__(self, item):
    """Return ``self.samples[item]``, resolving delayed samples first."""
    self._load()
    return self.samples[item]
def __setitem__(self, key, item):
    """Store ``item`` at ``key`` in the underlying samples container."""
    self._load()
    self.samples[key] = item
def __delitem__(self, item):
    """Delete the entry at ``item`` from the underlying samples container."""
    self._load()
    del self.samples[item]
def insert(self, index, item):
    """Insert ``item`` at position ``index`` of the underlying samples."""
    self._load()
    # No membership test is made, so an item already present is inserted again.
    container = self.samples
    container.insert(index, item)
class DelayedSampleSet(SampleSet):
    """A sample set whose samples are produced on demand by a ``load`` callable.

    ``load`` is called with no arguments the first time ``samples`` is read
    and its result is kept in ``_data`` for later reads. A result of ``None``
    is indistinguishable from "not loaded yet", so ``load`` would be called
    again on the next access in that case.

    Parameters
    ----------
    load : callable
        Zero-argument callable returning the samples (presumably a sequence
        of samples -- confirm against callers).
    parent : object, optional
        Object whose ``__dict__`` entries are copied onto this instance.
    **kwargs
        Extra attributes; copied after ``parent``, so they win on conflicts.
    """

    def __init__(self, load, parent=None, **kwargs):
        self._data = None
        self.load = load
        # Order matters: keyword arguments are applied last and override
        # whatever was taken from ``parent``.
        sources = [kwargs] if parent is None else [parent.__dict__, kwargs]
        for attributes in sources:
            _copy_attributes(self, attributes)

    @property
    def samples(self):
        """The loaded samples; ``load`` is invoked on first access."""
        cached = self._data
        if cached is None:
            cached = self._data = self.load()
        return cached
class SampleBatch(Sequence, _ReprMixin):
"""A batch of samples that looks like [s.data for s in samples]
......
from bob.pipelines import (
Sample,
DelayedSample,
SampleSet,
DelayedSampleSet,
sample_to_hdf5,
hdf5_to_sample,
)
import bob.io.base
import numpy as np
import copy
......@@ -44,7 +43,7 @@ def test_sampleset_collection():
def _load(path):
    """Return the object stored in the pickle file at ``path``."""
    # A context manager closes the handle even when unpickling raises.
    # NOTE: pickle must only be used on trusted files such as ones the test
    # itself has written.
    with open(path, "rb") as f:
        return pickle.load(f)
# Testing delayed sample in the sampleset
# Testing delayed sampleset
with tempfile.TemporaryDirectory() as dir_name:
samples = [Sample(data, key=str(i)) for i, data in enumerate(X)]
......@@ -52,9 +51,10 @@ def test_sampleset_collection():
with open(filename, "wb") as f:
f.write(pickle.dumps(samples))
sampleset = SampleSet(DelayedSample(functools.partial(_load, filename)), key=1)
sampleset = DelayedSampleSet(functools.partial(_load, filename), key=1)
assert len(sampleset) == n_samples
assert sampleset.samples == samples
def test_sample_hdf5():
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment