Commit 5d008f46 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI

Remove samples_to_hdf5 methods

parent b7c20eb4
Pipeline #53464 passed with stage
in 18 minutes and 52 seconds
......@@ -8,8 +8,6 @@ from .sample import DelayedSampleSetCached
from .sample import Sample
from .sample import SampleBatch
from .sample import SampleSet
from .sample import hdf5_to_sample # noqa: F401
from .sample import sample_to_hdf5 # noqa: F401
from .wrappers import BaseWrapper
from .wrappers import CheckpointWrapper
from .wrappers import DaskWrapper
......
......@@ -4,7 +4,6 @@ from collections.abc import MutableSequence
from collections.abc import Sequence
from typing import Any
import h5py
import numpy as np
from bob.io.base import vstack_features
......@@ -260,55 +259,3 @@ class SampleBatch(Sequence, _ReprMixin):
arr = vstack_features(_reader, self.samples, dtype=dtype)
return np.asarray(arr, dtype, *args, **kwargs)
def sample_to_hdf5(sample, hdf5):
"""
Saves the content of sample to hdf5 file
Parameters
----------
sample: :any:`Sample` or :any:`DelayedSample` or :any:`list`
Sample to be saved
hdf5: `h5py.File`
Pointer to a HDF5 file for writing
"""
if isinstance(sample, list):
for i, s in enumerate(sample):
group = hdf5.create_group(str(i))
sample_to_hdf5(s, group)
else:
for s in sample.__dict__:
hdf5[s] = getattr(sample, s)
def hdf5_to_sample(hdf5):
"""
Reads the content of a HDF5File and returns a :any:`Sample`
Parameters
----------
hdf5: `h5py.File`
Pointer to a HDF5 file for reading
"""
# Checking if it has groups
has_groups = np.sum([isinstance(hdf5[k], h5py.Group) for k in hdf5.keys()]) > 0
if has_groups:
# If has groups, returns a list of Samples
samples = []
for k in hdf5.keys():
group = hdf5[k]
samples.append(hdf5_to_sample(group))
return samples
else:
# If hasn't groups, returns a sample
sample = Sample(None)
for k in hdf5.keys():
setattr(sample, k, hdf5[k].value)
return sample
......@@ -4,7 +4,6 @@ import os
import pickle
import tempfile
import h5py
import numpy as np
from bob.pipelines import DelayedSample
......@@ -12,8 +11,6 @@ from bob.pipelines import DelayedSampleSet
from bob.pipelines import DelayedSampleSetCached
from bob.pipelines import Sample
from bob.pipelines import SampleSet
from bob.pipelines import hdf5_to_sample
from bob.pipelines import sample_to_hdf5
def test_sampleset_collection():
......@@ -71,36 +68,6 @@ def test_sampleset_collection():
assert sampleset.samples == samples
def test_sample_hdf5():
n_samples = 10
X = np.ones(shape=(n_samples, 2), dtype=int)
samples = [Sample(data, key=str(i), subject="Subject") for i, data in enumerate(X)]
with tempfile.TemporaryDirectory() as dir_name:
# Single sample
filename = os.path.join(dir_name, "sample.hdf5")
with h5py.File(filename, "w", driver="core") as hdf5:
sample_to_hdf5(samples[0], hdf5)
with h5py.File(filename, "r") as hdf5:
sample = hdf5_to_sample(hdf5)
assert sample == samples[0]
# List of samples
filename = os.path.join(dir_name, "samples.hdf5")
with h5py.File(filename, "w", driver="core") as hdf5:
sample_to_hdf5(samples, hdf5)
with h5py.File(filename, "r") as hdf5:
samples_deserialized = hdf5_to_sample(hdf5)
compare = [a == b for a, b in zip(samples_deserialized, samples)]
assert np.sum(compare) == 10
def test_delayed_samples():
def load_data():
return 0
......
......@@ -102,7 +102,7 @@ def _assert_all_close_numpy_array(oracle, result):
def test_sklearn_compatible_estimator():
# check classes for API consistency
check_estimator(DummyWithFit)
check_estimator(DummyWithFit())
def test_function_sample_transfomer():
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment