Serializing Samples

parent a31bbd39
Pipeline #40035 passed with stage in 8 minutes and 58 seconds
@@ -3,6 +3,18 @@
from collections.abc import MutableSequence, Sequence
from .utils import vstack_features
import numpy as np
from distributed.protocol.serialize import (
    serialize,
    deserialize,
    dask_serialize,
    dask_deserialize,
    register_generic,
)
import cloudpickle
import logging

logger = logging.getLogger(__name__)
def _copy_attributes(s, d):
@@ -89,6 +101,15 @@ class DelayedSample(_ReprMixin):
            self._data = self.load()
        return self._data

    # def __getstate__(self):
    #     d = dict(self.__dict__)
    #     d.pop("_data", None)
    #     return d

    # def __setstate__(self, d):
    #     self._data = d.pop("_data", None)
    #     self.__dict__.update(d)
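The commented-out hooks above sketch a plain-pickle alternative to the dask handlers added below: drop the cached `_data` before pickling so only the loader travels, and start with an empty cache on unpickling. A minimal, hypothetical illustration of that pattern (the `LazyThing` class is invented for this example; cloudpickle is used so the lambda loader is picklable):

import cloudpickle

class LazyThing:
    # hypothetical stand-in for DelayedSample's lazy-loading behavior
    def __init__(self, load):
        self.load = load
        self._data = None

    def __getstate__(self):
        # drop the (possibly large) cached payload before pickling
        d = dict(self.__dict__)
        d.pop("_data", None)
        return d

    def __setstate__(self, d):
        # restore with an empty cache; data is re-loaded on demand
        self._data = d.pop("_data", None)
        self.__dict__.update(d)

thing = LazyThing(load=lambda: list(range(10)))
thing._data = thing.load()                           # populate the cache
clone = cloudpickle.loads(cloudpickle.dumps(thing))
assert clone._data is None                           # the cache did not travel
assert clone.load() == thing._data                   # but it can be rebuilt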
class SampleSet(MutableSequence, _ReprMixin):
"""A set of samples with extra attributes"""
@@ -99,7 +120,6 @@ class SampleSet(MutableSequence, _ReprMixin):
        _copy_attributes(self, parent.__dict__)
        _copy_attributes(self, kwargs)

    def _load(self):
        if isinstance(self.samples, DelayedSample):
            self.samples = self.samples.data
@@ -146,5 +166,107 @@ class SampleBatch(Sequence, _ReprMixin):
        def _reader(s):
            # adding one more dimension to data so they get stacked sample-wise
            return s.data[None, ...]

        arr = vstack_features(_reader, self.samples, dtype=dtype)
        return np.asarray(arr, dtype, *args, **kwargs)
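For context on the `s.data[None, ...]` trick: indexing with `None` adds a leading axis, so vertically stacking the per-sample arrays concatenates them sample-wise. A quick numpy illustration, independent of `vstack_features` itself:

import numpy as np

a = np.random.rand(20, 1)  # one sample's features
b = np.random.rand(20, 1)  # another sample's features

batch = np.vstack([a[None, ...], b[None, ...]])
assert batch.shape == (2, 20, 1)  # the first axis now indexes samples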
def get_serialized_sample_header(sample):
    # ship every sample attribute except the payload itself
    sample_header = dict(
        (k, v)
        for k, v in sample.__dict__.items()
        if k not in ("data", "load", "samples", "_data")
    )
    return cloudpickle.dumps(sample_header)


def deserialize_sample_header(sample):
    return cloudpickle.loads(sample)
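A quick round trip through the two helpers above, assuming a Sample stores only its constructor kwargs (the `subject` attribute is just an illustrative kwarg; any extra attribute would do):

s = Sample(np.ones((2, 2)), key="a", subject="b")
blob = get_serialized_sample_header(s)  # cloudpickle bytes, safe to put in a header
assert deserialize_sample_header(blob) == {"key": "a", "subject": "b"}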
@dask_serialize.register(SampleSet)
def serialize_sampleset(sampleset):
    def serialize_delayed_sample(delayed_sample):
        header_sample = get_serialized_sample_header(delayed_sample)
        frame_sample = cloudpickle.dumps(delayed_sample)
        return header_sample, frame_sample

    # Ship the header of the sampleset in the header of the message
    header = dict()
    header["sampleset_header"] = get_serialized_sample_header(sampleset)
    header["sample_header"] = []
    frames = []

    # Checking first if our sampleset.samples are shipped as a DelayedSample
    if isinstance(sampleset.samples, DelayedSample):
        header_sample, frame_sample = serialize_delayed_sample(sampleset.samples)
        frames += [frame_sample]
        header["sample_header"].append(header_sample)
        header["sample_type"] = "DelayedSampleList"
    else:
        for sample in sampleset.samples:
            if isinstance(sample, DelayedSample):
                header_sample, frame_sample = serialize_delayed_sample(sample)
                frame_sample = [frame_sample]
            else:
                header_sample, frame_sample = serialize(sample)
            frames += frame_sample
            header["sample_header"].append(header_sample)
            header["sample_type"] = (
                "DelayedSample" if isinstance(sample, DelayedSample) else "Sample"
            )

    return header, frames
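A sanity check of the wire format produced above (mirroring the tests further down): the header is a plain dict of strings and cloudpickle bytes, while the frames carry the data buffers. For a small contiguous array, `serialize` should yield a single frame:

ss = SampleSet([Sample(np.ones(3), key="0")], key="1")
header, frames = serialize_sampleset(ss)
assert header["sample_type"] == "Sample"
assert isinstance(header["sampleset_header"], bytes)
assert len(header["sample_header"]) == len(frames) == 1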
@dask_deserialize.register(SampleSet)
def deserialize_sampleset(header, frames):
    if "sample_header" not in header:
        raise ValueError(
            "Problem with SampleSet serialization: `sample_header` not found"
        )

    sampleset_header = deserialize_sample_header(header["sampleset_header"])
    sampleset = SampleSet([], **sampleset_header)

    if header["sample_type"] == "DelayedSampleList":
        sampleset.samples = cloudpickle.loads(frames[0])
        return sampleset

    for h, f in zip(header["sample_header"], frames):
        if header["sample_type"] == "Sample":
            data = dask_deserialize.dispatch(Sample)(h, [f])
            sampleset.samples.append(data)
        else:
            sampleset.samples.append(cloudpickle.loads(f))

    return sampleset
@dask_serialize.register(Sample)
def serialize_sample(sample):
    header_sample = get_serialized_sample_header(sample)
    # If data is a numpy array, this delegates to dask's native ndarray serializer
    header, frames = serialize(sample.data)
    header["sample"] = header_sample
    return header, frames
@dask_deserialize.register(Sample)
def deserialize_sample(header, frames):
    try:
        data = dask_deserialize.dispatch(np.ndarray)(header, frames)
    except KeyError:
        # non-ndarray payloads were pickled; the payload is the first frame
        data = cloudpickle.loads(frames[0])
    sample_header = deserialize_sample_header(header["sample"])
    sample = Sample(data, parent=None, **sample_header)
    return sample
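The `try/except KeyError` above distinguishes the two payload kinds: ndarray headers carry layout keys such as `dtype` and `shape`, which pickled payloads lack, so the ndarray deserializer raises `KeyError` and the code falls back to cloudpickle. A hypothetical round trip through this pair of handlers, assuming a working distributed environment:

sample = Sample(np.arange(6).reshape(2, 3), key="s0")
header, frames = serialize_sample(sample)
restored = deserialize_sample(header, frames)
assert np.array_equal(restored.data, sample.data)
assert restored.key == "s0"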
import bob.pipelines as mario
import numpy
from bob.pipelines import Sample, SampleSet, DelayedSample
import numpy as np
from distributed.protocol.serialize import serialize, deserialize
import copy
import pickle
import msgpack
import tempfile
import functools
import os
@@ -10,14 +12,14 @@ import os
def test_sampleset_collection():
    n_samples = 10
    X = numpy.ones(shape=(n_samples, 2), dtype=int)
    sampleset = mario.SampleSet(
        [mario.Sample(data, key=str(i)) for i, data in enumerate(X)], key="1"
    X = np.ones(shape=(n_samples, 2), dtype=int)
    sampleset = SampleSet(
        [Sample(data, key=str(i)) for i, data in enumerate(X)], key="1"
    )
    assert len(sampleset) == n_samples

    # Testing insert
    sample = mario.Sample(X, key=100)
    sample = Sample(X, key=100)
    sampleset.insert(1, sample)
    assert len(sampleset) == n_samples + 1
@@ -30,7 +32,7 @@ def test_sampleset_collection():
    # Testing iterator
    for i in sampleset:
        assert isinstance(i, mario.Sample)
        assert isinstance(i, Sample)
def _load(path):
@@ -39,11 +41,114 @@ def test_sampleset_collection():
    # Testing delayed sample in the sampleset
    with tempfile.TemporaryDirectory() as dir_name:
        samples = [mario.Sample(data, key=str(i)) for i, data in enumerate(X)]
        samples = [Sample(data, key=str(i)) for i, data in enumerate(X)]
        filename = os.path.join(dir_name, "samples.pkl")
        with open(filename, "wb") as f:
            f.write(pickle.dumps(samples))

        sampleset = mario.SampleSet(mario.DelayedSample(functools.partial(_load, filename)), key=1)
        sampleset = SampleSet(DelayedSample(functools.partial(_load, filename)), key=1)
        assert len(sampleset) == n_samples
\ No newline at end of file
        assert len(sampleset) == n_samples
def test_sample_serialization():
    sample = Sample(np.random.rand(1, 1, 2), key=1)
    header, frame = serialize(sample)
    deserialized_sample = deserialize(header, frame)
    assert isinstance(deserialized_sample, Sample)

    # Testing serialization of a SampleSet
    sample = Sample(np.random.rand(1, 1, 2), key=1)
    sampleset = SampleSet([sample], key=1)
    header, frame = serialize(sampleset)
    deserialized_sampleset = deserialize(header, frame)
    assert isinstance(deserialized_sampleset, SampleSet)
    deserialized_sampleset[0] = Sample(np.random.rand(3, 480, 400), key=1)

    # Serialize again
    header, frame = serialize(deserialized_sampleset)
    deserialized_sampleset = deserialize(header, frame)
    assert isinstance(deserialized_sampleset, SampleSet)

    # Testing list serialization
    header, frame = serialize([deserialized_sampleset])
    deserialized_sampleset = deserialize(header, frame)
    assert isinstance(deserialized_sampleset, list)
    assert isinstance(deserialized_sampleset[0], SampleSet)
def test_sample_serialization_scale():
    def create_samplesets(n_sample_sets, n_samples):
        return [
            SampleSet(
                [Sample(data=np.random.rand(20, 1)) for _ in range(n_samples)],
                key=i,
                references=list(range(1000)),
            )
            for i in range(n_sample_sets)
        ]

    samplesets = create_samplesets(10, 10)
    header, frame = serialize(samplesets)

    # the header needs to be serializable with msgpack (see the note after this test)
    msgpack.dumps(header)

    deserialized_samplesets = deserialize(header, frame)
    assert isinstance(deserialized_samplesets, list)
    assert isinstance(deserialized_samplesets[0], SampleSet)
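The `msgpack.dumps(header)` check matters because dask ships message headers with msgpack, which handles plain dicts, strings, numbers, and bytes but not arbitrary Python objects; that is why sample attributes travel inside the header as cloudpickle bytes rather than as live objects. For instance:

import cloudpickle
import msgpack
import numpy as np

msgpack.dumps({"references": cloudpickle.dumps(np.arange(3))})  # fine: just bytes
try:
    msgpack.dumps({"references": np.arange(3)})  # a raw ndarray is rejected
except TypeError:
    pass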
def test_sample_serialization_delayed():
    with tempfile.TemporaryDirectory() as dir_name:

        def create_samplesets(n_sample_sets, n_samples, as_list=False):
            samples = [Sample(data=np.random.rand(20, 1)) for _ in range(n_samples)]
            filename = os.path.join(dir_name, "xuxa.pkl")
            with open(filename, "wb") as f:
                f.write(pickle.dumps(samples))

            def _load(path):
                with open(path, "rb") as f:
                    return pickle.loads(f.read())

            if as_list:
                delayed_samples = [
                    DelayedSample(
                        functools.partial(_load, filename),
                        key=1,
                        references=list(range(1000)),
                    )
                ]
            else:
                delayed_samples = DelayedSample(
                    functools.partial(_load, filename),
                    key=1,
                    references=np.array(list(range(1000)), dtype="float"),
                )

            return [
                SampleSet(
                    delayed_samples,
                    key=i,
                    references=np.array(list(range(1000)), dtype="float"),
                )
                for i in range(n_sample_sets)
            ]

        samplesets = create_samplesets(1, 10, as_list=False)
        header, frame = serialize(samplesets)

        # the header needs to be serializable with msgpack
        msgpack.dumps(header)

        deserialized_samplesets = deserialize(header, frame)
        assert isinstance(deserialized_samplesets, list)
        assert isinstance(deserialized_samplesets[0], SampleSet)

        # Testing a list of samplesets
        samplesets = create_samplesets(1, 10, as_list=True)
        header, frame = serialize(samplesets)

        # the header needs to be serializable with msgpack
        msgpack.dumps(header)

        deserialized_samplesets = deserialize(header, frame)
        assert isinstance(deserialized_samplesets, list)
        assert isinstance(deserialized_samplesets[0], SampleSet)
\ No newline at end of file