From 80b35bcbf00ffc071a3ca837f73ba380c5066155 Mon Sep 17 00:00:00 2001 From: Tiago Freitas Pereira <tiagofrepereira@gmail.com> Date: Thu, 26 Mar 2020 11:25:01 +0100 Subject: [PATCH] Making dask work again --- bob/bio/base/config/examples/pca_atnt.py | 4 ++-- bob/bio/base/pipelines/vanilla_biometrics/mixins.py | 4 ++-- bob/bio/base/script/vanilla_biometrics.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/bob/bio/base/config/examples/pca_atnt.py b/bob/bio/base/config/examples/pca_atnt.py index 792e6774..afffe2df 100644 --- a/bob/bio/base/config/examples/pca_atnt.py +++ b/bob/bio/base/config/examples/pca_atnt.py @@ -24,5 +24,5 @@ from bob.bio.base.pipelines.vanilla_biometrics.mixins import ( BioAlgDaskMixin, ) -#transformer = estimator_dask_it(transformer) -#algorithm = mix_me_up(BioAlgDaskMixin, algorithm) +transformer = estimator_dask_it(transformer) +algorithm = mix_me_up(BioAlgDaskMixin, algorithm) diff --git a/bob/bio/base/pipelines/vanilla_biometrics/mixins.py b/bob/bio/base/pipelines/vanilla_biometrics/mixins.py index a61f9a59..c335a9aa 100644 --- a/bob/bio/base/pipelines/vanilla_biometrics/mixins.py +++ b/bob/bio/base/pipelines/vanilla_biometrics/mixins.py @@ -87,7 +87,7 @@ class BioAlgCheckpointMixin(CheckpointMixin): class BioAlgDaskMixin: def enroll_samples(self, biometric_reference_features): biometric_references = biometric_reference_features.map_partitions( - self.enroll_samples + super().enroll_samples ) return biometric_references @@ -102,5 +102,5 @@ class BioAlgDaskMixin: all_references = dask.delayed(list)(biometric_references) - scores = probe_features.map_partitions(self.score_samples, all_references) + scores = probe_features.map_partitions(super().score_samples, all_references) return scores diff --git a/bob/bio/base/script/vanilla_biometrics.py b/bob/bio/base/script/vanilla_biometrics.py index 26c9a008..abe07934 100644 --- a/bob/bio/base/script/vanilla_biometrics.py +++ b/bob/bio/base/script/vanilla_biometrics.py @@ -183,7 +183,7 @@ def vanilla_biometrics( logger.warning( "`dask_client` not set. Your pipeline will run locally" ) - result = result.compute() + result = result.compute(scheduler="single-threaded") # Flatting out the list result = itertools.chain(*result) -- GitLab