diff --git a/bob/pipelines/sample.py b/bob/pipelines/sample.py
index 8318d1e837bb8ac01630aa2c23e4525e8c88cfbb..2b3693fe4b5fa9504d853b92345c6655b5cca3b4 100644
--- a/bob/pipelines/sample.py
+++ b/bob/pipelines/sample.py
@@ -3,6 +3,18 @@
 from collections.abc import MutableSequence, Sequence
 from .utils import vstack_features
 import numpy as np
+from distributed.protocol.serialize import (
+    serialize,
+    deserialize,
+    dask_serialize,
+    dask_deserialize,
+    register_generic,
+)
+import cloudpickle
+
+import logging
+
+logger = logging.getLogger(__name__)
 
 
 def _copy_attributes(s, d):
@@ -89,6 +101,15 @@ class DelayedSample(_ReprMixin):
             self._data = self.load()
         return self._data
 
+    # def __getstate__(self):
+    #     d = dict(self.__dict__)
+    #     d.pop("_data", None)
+    #     return d
+
+    # def __setstate__(self, d):
+    #     self._data = d.pop("_data", None)
+    #     self.__dict__.update(d)
+
 
 class SampleSet(MutableSequence, _ReprMixin):
     """A set of samples with extra attributes"""
@@ -99,7 +120,6 @@ class SampleSet(MutableSequence, _ReprMixin):
         _copy_attributes(self, parent.__dict__)
         _copy_attributes(self, kwargs)
 
-
     def _load(self):
         if isinstance(self.samples, DelayedSample):
             self.samples = self.samples.data
@@ -146,5 +166,107 @@ class SampleBatch(Sequence, _ReprMixin):
         def _reader(s):
             # adding one more dimension to data so they get stacked sample-wise
             return s.data[None, ...]
+
         arr = vstack_features(_reader, self.samples, dtype=dtype)
         return np.asarray(arr, dtype, *args, **kwargs)
+
+
+def get_serialized_sample_header(sample):
+    # every attribute except the data payload ships in the header
+    sample_header = {
+        k: v
+        for k, v in sample.__dict__.items()
+        if k not in ("data", "load", "samples", "_data")
+    }
+
+    return cloudpickle.dumps(sample_header)
+
+
+def deserialize_sample_header(serialized_header):
+    return cloudpickle.loads(serialized_header)
+
+
+@dask_serialize.register(SampleSet)
+def serialize_sampleset(sampleset):
+    def serialize_delayed_sample(delayed_sample):
+        header_sample = get_serialized_sample_header(delayed_sample)
+        frame_sample = cloudpickle.dumps(delayed_sample)
+        return header_sample, frame_sample
+
+    header = dict()
+
+    # ship the attributes of the sampleset itself
+    # in the header of the message
+    header["sampleset_header"] = get_serialized_sample_header(sampleset)
+
+    header["sample_header"] = []
+    frames = []
+
+    # check first whether sampleset.samples is shipped as a single DelayedSample
+    if isinstance(sampleset.samples, DelayedSample):
+        header_sample, frame_sample = serialize_delayed_sample(sampleset.samples)
+        frames += [frame_sample]
+        header["sample_header"].append(header_sample)
+        header["sample_type"] = "DelayedSampleList"
+    else:
+        for sample in sampleset.samples:
+            if isinstance(sample, DelayedSample):
+                header_sample, frame_sample = serialize_delayed_sample(sample)
+                frame_sample = [frame_sample]
+            else:
+                # deserialize_sampleset pairs one frame with one sample header,
+                # so this relies on serialize() emitting a single frame here
+                header_sample, frame_sample = serialize(sample)
+            frames += frame_sample
+            header["sample_header"].append(header_sample)
+
+            # assumes every sample in the set has the same type
+            header["sample_type"] = (
+                "DelayedSample" if isinstance(sample, DelayedSample) else "Sample"
+            )
+
+    return header, frames
+
+
+@dask_deserialize.register(SampleSet)
+def deserialize_sampleset(header, frames):
+    if "sample_header" not in header:
+        raise ValueError(
+            "Problem with SampleSet serialization: `sample_header` not found"
+        )
+
+    sampleset_header = deserialize_sample_header(header["sampleset_header"])
+    sampleset = SampleSet([], **sampleset_header)
+
+    if header["sample_type"] == "DelayedSampleList":
+        sampleset.samples = cloudpickle.loads(frames[0])
+        return sampleset
+
+    for h, f in zip(header["sample_header"], frames):
+        if header["sample_type"] == "Sample":
+            data = dask_deserialize.dispatch(Sample)(h, [f])
+            sampleset.samples.append(data)
+        else:
+            sampleset.samples.append(cloudpickle.loads(f))
+
+    return sampleset
+
+
+@dask_serialize.register(Sample)
+def serialize_sample(sample):
+    header_sample = get_serialized_sample_header(sample)
+
+    # if data is a numpy array, this uses dask's native serializer
+    header, frames = serialize(sample.data)
+    header["sample"] = header_sample
+
+    return header, frames
+
+
+@dask_deserialize.register(Sample)
+def deserialize_sample(header, frames):
+    try:
+        data = dask_deserialize.dispatch(np.ndarray)(header, frames)
+    except KeyError:
+        # the data was not shipped as a numpy array; fall back to cloudpickle
+        data = cloudpickle.loads(frames[0])
+
+    sample_header = deserialize_sample_header(header["sample"])
+    sample = Sample(data, parent=None, **sample_header)
+    return sample
diff --git a/bob/pipelines/tests/test_samples.py b/bob/pipelines/tests/test_samples.py
index d622b70fdc29128c2635cea98ff318039e8825d3..d6ebdfec7e71877eac9b241cfd430e61569448b1 100644
--- a/bob/pipelines/tests/test_samples.py
+++ b/bob/pipelines/tests/test_samples.py
@@ -1,8 +1,10 @@
-import bob.pipelines as mario
-import numpy
+from bob.pipelines import Sample, SampleSet, DelayedSample
+import numpy as np
+from distributed.protocol.serialize import serialize, deserialize
 
 import copy
 import pickle
+import msgpack
 import tempfile
 import functools
 import os
@@ -10,14 +12,14 @@
 def test_sampleset_collection():
 
     n_samples = 10
-    X = numpy.ones(shape=(n_samples, 2), dtype=int)
-    sampleset = mario.SampleSet(
-        [mario.Sample(data, key=str(i)) for i, data in enumerate(X)], key="1"
+    X = np.ones(shape=(n_samples, 2), dtype=int)
+    sampleset = SampleSet(
+        [Sample(data, key=str(i)) for i, data in enumerate(X)], key="1"
     )
     assert len(sampleset) == n_samples
 
     # Testing insert
-    sample = mario.Sample(X, key=100)
+    sample = Sample(X, key=100)
     sampleset.insert(1, sample)
     assert len(sampleset) == n_samples + 1
@@ -30,7 +32,7 @@
 
     # Testing iterator
     for i in sampleset:
-        assert isinstance(i, mario.Sample)
+        assert isinstance(i, Sample)
 
     def _load(path):
@@ -39,11 +41,114 @@
     # Testing delayed sample in the sampleset
     with tempfile.TemporaryDirectory() as dir_name:
-        samples = [mario.Sample(data, key=str(i)) for i, data in enumerate(X)]
+        samples = [Sample(data, key=str(i)) for i, data in enumerate(X)]
         filename = os.path.join(dir_name, "samples.pkl")
         with open(filename, "wb") as f:
             f.write(pickle.dumps(samples))
-        sampleset = mario.SampleSet(mario.DelayedSample(functools.partial(_load, filename)), key=1)
+        sampleset = SampleSet(DelayedSample(functools.partial(_load, filename)), key=1)
 
-        assert len(sampleset)==n_samples
\ No newline at end of file
+        assert len(sampleset) == n_samples
+
+
+def test_sample_serialization():
+    sample = Sample(np.random.rand(1, 1, 2), key=1)
+    header, frame = serialize(sample)
+    deserialized_sample = deserialize(header, frame)
+    assert isinstance(deserialized_sample, Sample)
+
+    # Testing SampleSet serialization
+    sample = Sample(np.random.rand(1, 1, 2), key=1)
+    sampleset = SampleSet([sample], key=1)
+    header, frame = serialize(sampleset)
+    deserialized_sampleset = deserialize(header, frame)
+
+    assert isinstance(deserialized_sampleset, SampleSet)
+    deserialized_sampleset[0] = Sample(np.random.rand(3, 480, 400), key=1)
+
+    # serialize again after mutation
+    header, frame = serialize(deserialized_sampleset)
+    deserialized_sampleset = deserialize(header, frame)
+    assert isinstance(deserialized_sampleset, SampleSet)
+
+    # Testing list serialization
+    header, frame = serialize([deserialized_sampleset])
+    deserialized_sampleset = deserialize(header, frame)
+    assert isinstance(deserialized_sampleset, list)
+    assert isinstance(deserialized_sampleset[0], SampleSet)
+
+
+def test_sample_serialization_scale():
+    def create_samplesets(n_sample_sets, n_samples):
+        return [
+            SampleSet(
+                [Sample(data=np.random.rand(20, 1)) for _ in range(n_samples)],
+                key=i,
+                references=list(range(1000)),
+            )
+            for i in range(n_sample_sets)
+        ]
+
+    samplesets = create_samplesets(10, 10)
+    header, frame = serialize(samplesets)
+
+    # the header needs to be serializable with msgpack
+    msgpack.dumps(header)
+
+    deserialized_samplesets = deserialize(header, frame)
+    assert isinstance(deserialized_samplesets, list)
+    assert isinstance(deserialized_samplesets[0], SampleSet)
+
+
+def test_sample_serialization_delayed():
+
+    with tempfile.TemporaryDirectory() as dir_name:
+
+        def create_samplesets(n_sample_sets, n_samples, as_list=False):
+
+            samples = [Sample(data=np.random.rand(20, 1)) for _ in range(n_samples)]
+            filename = os.path.join(dir_name, "xuxa.pkl")
+            with open(filename, "wb") as f:
+                f.write(pickle.dumps(samples))
+
+            def _load(path):
+                with open(path, "rb") as f:
+                    return pickle.loads(f.read())
+
+            if as_list:
+                delayed_samples = [
+                    DelayedSample(
+                        functools.partial(_load, filename),
+                        key=1,
+                        references=list(range(1000)),
+                    )
+                ]
+            else:
+                delayed_samples = DelayedSample(
+                    functools.partial(_load, filename),
+                    key=1,
+                    references=np.array(list(range(1000)), dtype="float"),
+                )
+
+            return [
+                SampleSet(
+                    delayed_samples,
+                    key=i,
+                    references=np.array(list(range(1000)), dtype="float"),
+                )
+                for i in range(n_sample_sets)
+            ]
+
+        samplesets = create_samplesets(1, 10, as_list=False)
+        header, frame = serialize(samplesets)
+
+        # the header needs to be serializable with msgpack
+        msgpack.dumps(header)
+
+        deserialized_samplesets = deserialize(header, frame)
+
+        assert isinstance(deserialized_samplesets, list)
+        assert isinstance(deserialized_samplesets[0], SampleSet)
+
+        # Testing a list of DelayedSamples inside the samplesets
+        samplesets = create_samplesets(1, 10, as_list=True)
+        header, frame = serialize(samplesets)
+
+        # the header needs to be serializable with msgpack
+        msgpack.dumps(header)
+
+        deserialized_samplesets = deserialize(header, frame)
+        assert isinstance(deserialized_samplesets, list)
+        assert isinstance(deserialized_samplesets[0], SampleSet)
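
Note: for reviewers, a minimal round-trip sketch of how these serializers are exercised end to end, assuming this patch is applied and `distributed` is installed; the data shape and keys below are made up for illustration:

import numpy as np
from bob.pipelines import Sample, SampleSet
from distributed.protocol.serialize import serialize, deserialize

# serialize() dispatches to serialize_sampleset() through
# @dask_serialize.register(SampleSet): the picklable metadata travels in
# `header`, the payload bytes travel in `frames`.
sampleset = SampleSet([Sample(np.random.rand(4, 2), key="0")], key="set-1")
header, frames = serialize(sampleset)

# deserialize() routes back through deserialize_sampleset() and rebuilds
# the SampleSet with its attributes and sample data intact.
restored = deserialize(header, frames)
assert isinstance(restored, SampleSet)
assert np.allclose(restored[0].data, sampleset[0].data)

This is the same header/frames split that dask workers use on the wire, which is why the tests above check that the header survives msgpack.dumps().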