Commit 0414d3be authored by Tiago de Freitas Pereira

Polishing score serialization. It's impossible to do it in HDF5

parent d9f5534b
2 merge requests: !188 "Score normalizations", !180 "[dask] Preparing bob.bio.base for dask pipelines"
Pipeline #40060 passed
@@ -5,9 +5,10 @@ import dask
 import functools
 from .score_writers import FourColumnsScoreWriter
 from .abstract_classes import BioAlgorithm
-import pickle
 import bob.pipelines as mario
 import numpy as np
+import h5py
+import cloudpickle
 from .zt_norm import ZTNormPipeline, ZTNormDaskWrapper
@@ -45,6 +46,7 @@ class BioAlgorithmCheckpointWrapper(BioAlgorithm):
         self.biometric_algorithm = biometric_algorithm
         self.force = force
         self._biometric_reference_extension = ".hdf5"
+        self._score_extension = ".pkl"
         self.base_dir = base_dir

     def enroll(self, enroll_features):
@@ -63,7 +65,9 @@ class BioAlgorithmCheckpointWrapper(BioAlgorithm):

     def write_scores(self, samples, path):
         os.makedirs(os.path.dirname(path), exist_ok=True)
-        open(path, "wb").write(pickle.dumps(samples))
+        # cleaning parent
+        with open(path, "wb") as f:
+            f.write(cloudpickle.dumps(samples))

     def _enroll_sample_set(self, sampleset):
         """
@@ -99,17 +103,21 @@ class BioAlgorithmCheckpointWrapper(BioAlgorithm):
         """

         def _load(path):
-            return pickle.loads(open(path, "rb").read())
+            return cloudpickle.loads(open(path, "rb").read())
+
+            #with h5py.File(path) as hdf5:
+            #    return hdf5_to_sample(hdf5)

         def _make_name(sampleset, biometric_references):
             # The score file name is composed by sampleset key and the
             # first 3 biometric_references
+            subject = str(sampleset.subject)
             name = str(sampleset.key)
             suffix = "_".join([str(s.key) for s in biometric_references[0:3]])
-            return name + suffix
+            return os.path.join(subject, name + suffix)

         path = os.path.join(
-            self.score_dir, _make_name(sampleset, biometric_references) + ".pkl"
+            self.score_dir, _make_name(sampleset, biometric_references) + self._score_extension
         )

         if self.force or not os.path.exists(path):
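With the new subject component in _make_name, score files are grouped into one directory per subject instead of a single flat directory. A hypothetical illustration of the resulting file layout (all values invented):

import os

score_dir = "./scores"                         # assumed checkpoint directory
subject = "subj-042"                           # sampleset.subject
name = "probe-001"                             # sampleset.key
reference_keys = ["m01", "m02", "m03", "m04"]  # keys of biometric_references

suffix = "_".join(reference_keys[0:3])         # only the first 3 references
path = os.path.join(score_dir, subject, name + suffix + ".pkl")
print(path)  # ./scores/subj-042/probe-001m01_m02_m03.pkl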
@@ -123,11 +131,12 @@ class BioAlgorithmCheckpointWrapper(BioAlgorithm):
             self.write_scores(scored_sample_set.samples, path)

             scored_sample_set = SampleSet(
-                [DelayedSample(functools.partial(_load, path), parent=sampleset)],
+                DelayedSample(functools.partial(_load, path), parent=sampleset),
                 parent=sampleset,
             )
         else:
-            scored_sample_set = SampleSet(_load(path), parent=sampleset)
+            samples = _load(path)
+            scored_sample_set = SampleSet(samples, parent=sampleset)

         return scored_sample_set
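On the checkpointed branch the scores are wrapped in a DelayedSample, so the pickle written above is only read back when the sample's data is actually requested. A stand-in sketch of that lazy-loading pattern (not bob.pipelines' real DelayedSample, just the semantics assumed here):

import functools


class LazySample:
    """Stand-in for bob.pipelines' DelayedSample: stores a loader, defers work."""

    def __init__(self, load, parent=None):
        self._load = load
        self.parent = parent

    @property
    def data(self):
        # the functools.partial(_load, path) is only invoked on first access
        return self._load()


sample = LazySample(functools.partial(str.upper, "lazy"))
print(sample.data)  # "LAZY" -- computed only when .data is read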
@@ -182,7 +191,7 @@ class BioAlgorithmDaskWrapper(BioAlgorithm):
     )

-def dask_vanilla_biometrics(pipeline, npartitions=None):
+def dask_vanilla_biometrics(pipeline, npartitions=None, partition_size=None):
     """
     Given a :any:`VanillaBiometrics`, wraps :any:`VanillaBiometrics.transformer` and
     :any:`VanillaBiometrics.biometric_algorithm` to be executed with dask
@@ -195,6 +204,9 @@ def dask_vanilla_biometrics(pipeline, npartitions=None):
     npartitions: int
         Number of partitions for the initial :any:`dask.bag`

+    partition_size: int
+        Size of the partition for the initial :any:`dask.bag`
+
     """

     if isinstance(pipeline, ZTNormPipeline):
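For context, a hedged sketch of what the two knobs correspond to in the underlying dask.bag API, assuming the wrapped transformer's data ends up in a bag built with dask.bag.from_sequence (which accepts either argument):

import dask.bag

samples = list(range(1000))

# fix the number of partitions up front...
bag_a = dask.bag.from_sequence(samples, npartitions=10)
# ...or fix how many items go into each partition instead
bag_b = dask.bag.from_sequence(samples, partition_size=200)
print(bag_a.npartitions, bag_b.npartitions)  # 10 5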
@@ -207,9 +219,14 @@ def dask_vanilla_biometrics(pipeline, npartitions=None):
     else:
-        pipeline.transformer = mario.wrap(
-            ["dask"], pipeline.transformer, npartitions=npartitions
-        )
+        if partition_size is None:
+            pipeline.transformer = mario.wrap(
+                ["dask"], pipeline.transformer, npartitions=npartitions
+            )
+        else:
+            pipeline.transformer = mario.wrap(
+                ["dask"], pipeline.transformer, partition_size=partition_size
+            )

         pipeline.biometric_algorithm = BioAlgorithmDaskWrapper(
             pipeline.biometric_algorithm
         )
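A hedged usage sketch of the extended entry point, assuming the module's public import path and a VanillaBiometrics pipeline built elsewhere; per the branch above, partition_size takes effect whenever it is not None:

from bob.bio.base.pipelines.vanilla_biometrics import dask_vanilla_biometrics

# pipeline: a VanillaBiometrics instance assembled elsewhere
pipeline = dask_vanilla_biometrics(pipeline, npartitions=48)     # fixed count
# or, sizing partitions by sample count instead:
pipeline = dask_vanilla_biometrics(pipeline, partition_size=200)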