diff --git a/bob/bio/base/pipelines/vanilla_biometrics/wrappers.py b/bob/bio/base/pipelines/vanilla_biometrics/wrappers.py
index 115c5ac910756dd07e3c0a3369014ffabe6d5940..157c0def30ece6b451796f480306906f821526ad 100644
--- a/bob/bio/base/pipelines/vanilla_biometrics/wrappers.py
+++ b/bob/bio/base/pipelines/vanilla_biometrics/wrappers.py
@@ -5,9 +5,10 @@ import dask
 import functools
 from .score_writers import FourColumnsScoreWriter
 from .abstract_classes import BioAlgorithm
-import pickle
 import bob.pipelines as mario
 import numpy as np
+import h5py
+import cloudpickle
 from .zt_norm import ZTNormPipeline, ZTNormDaskWrapper
 
 
@@ -45,6 +46,7 @@ class BioAlgorithmCheckpointWrapper(BioAlgorithm):
         self.biometric_algorithm = biometric_algorithm
         self.force = force
         self._biometric_reference_extension = ".hdf5"
+        self._score_extension = ".pkl"
         self.base_dir = base_dir
 
     def enroll(self, enroll_features):
@@ -63,7 +65,9 @@ class BioAlgorithmCheckpointWrapper(BioAlgorithm):
 
     def write_scores(self, samples, path):
         os.makedirs(os.path.dirname(path), exist_ok=True)
-        open(path, "wb").write(pickle.dumps(samples))
+        # cleaning parent
+        with open(path, "wb") as f:
+            f.write(cloudpickle.dumps(samples))
 
     def _enroll_sample_set(self, sampleset):
         """
@@ -99,17 +103,21 @@ class BioAlgorithmCheckpointWrapper(BioAlgorithm):
         """
 
         def _load(path):
-            return pickle.loads(open(path, "rb").read())
+            return cloudpickle.loads(open(path, "rb").read())
+
+            #with h5py.File(path) as hdf5:
+            #    return hdf5_to_sample(hdf5)
 
         def _make_name(sampleset, biometric_references):
             # The score file name is composed by sampleset key and the
             # first 3 biometric_references
+            subject = str(sampleset.subject)
             name = str(sampleset.key)
             suffix = "_".join([str(s.key) for s in biometric_references[0:3]])
-            return name + suffix
+            return os.path.join(subject, name + suffix)
 
         path = os.path.join(
-            self.score_dir, _make_name(sampleset, biometric_references) + ".pkl"
+            self.score_dir, _make_name(sampleset, biometric_references) + self._score_extension
         )
 
         if self.force or not os.path.exists(path):
@@ -123,11 +131,12 @@ class BioAlgorithmCheckpointWrapper(BioAlgorithm):
             self.write_scores(scored_sample_set.samples, path)
 
             scored_sample_set = SampleSet(
-                [DelayedSample(functools.partial(_load, path), parent=sampleset)],
+                DelayedSample(functools.partial(_load, path), parent=sampleset),
                 parent=sampleset,
             )
         else:
-            scored_sample_set = SampleSet(_load(path), parent=sampleset)
+            samples = _load(path)
+            scored_sample_set = SampleSet(samples, parent=sampleset)
 
         return scored_sample_set
 
@@ -182,7 +191,7 @@ class BioAlgorithmDaskWrapper(BioAlgorithm):
         )
 
 
-def dask_vanilla_biometrics(pipeline, npartitions=None):
+def dask_vanilla_biometrics(pipeline, npartitions=None, partition_size=None):
     """
     Given a :any:`VanillaBiometrics`, wraps :any:`VanillaBiometrics.transformer` and
     :any:`VanillaBiometrics.biometric_algorithm` to be executed with dask
@@ -195,6 +204,9 @@ def dask_vanilla_biometrics(pipeline, npartitions=None):
     npartitions: int
         Number of partitions for the initial :any:`dask.bag`
+
+    partition_size: int
+        Size of the partition for the initial :any:`dask.bag`
 
     """
 
     if isinstance(pipeline, ZTNormPipeline):
@@ -207,9 +219,14 @@ def dask_vanilla_biometrics(pipeline, npartitions=None):
 
     else:
-        pipeline.transformer = mario.wrap(
-            ["dask"], pipeline.transformer, npartitions=npartitions
-        )
+        if partition_size is None:
+            pipeline.transformer = mario.wrap(
+                ["dask"], pipeline.transformer, npartitions=npartitions
+            )
+        else:
+            pipeline.transformer = mario.wrap(
+                ["dask"], pipeline.transformer, partition_size=partition_size
+            )
 
     pipeline.biometric_algorithm = BioAlgorithmDaskWrapper(
         pipeline.biometric_algorithm
     )
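
A minimal usage sketch (not part of the commit), assuming an already-built VanillaBiometrics pipeline bound to the name `pipeline`; it only illustrates the new `partition_size` option introduced above:

    from bob.bio.base.pipelines.vanilla_biometrics.wrappers import dask_vanilla_biometrics

    # Per the diff, partition_size is used whenever it is not None;
    # otherwise the dask bag is split into npartitions as before.
    pipeline = dask_vanilla_biometrics(pipeline, partition_size=200)

    # Previous behaviour, still supported:
    # pipeline = dask_vanilla_biometrics(pipeline, npartitions=10)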