Skip to content
Snippets Groups Projects
Commit 80b35bcb authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira Committed by Amir MOHAMMADI
Browse files

Making dask work again

parent 00f6d036
No related branches found
No related tags found
1 merge request!180[dask] Preparing bob.bio.base for dask pipelines
Pipeline #38367 failed
...@@ -24,5 +24,5 @@ from bob.bio.base.pipelines.vanilla_biometrics.mixins import ( ...@@ -24,5 +24,5 @@ from bob.bio.base.pipelines.vanilla_biometrics.mixins import (
BioAlgDaskMixin, BioAlgDaskMixin,
) )
#transformer = estimator_dask_it(transformer) transformer = estimator_dask_it(transformer)
#algorithm = mix_me_up(BioAlgDaskMixin, algorithm) algorithm = mix_me_up(BioAlgDaskMixin, algorithm)
...@@ -87,7 +87,7 @@ class BioAlgCheckpointMixin(CheckpointMixin): ...@@ -87,7 +87,7 @@ class BioAlgCheckpointMixin(CheckpointMixin):
class BioAlgDaskMixin: class BioAlgDaskMixin:
def enroll_samples(self, biometric_reference_features): def enroll_samples(self, biometric_reference_features):
biometric_references = biometric_reference_features.map_partitions( biometric_references = biometric_reference_features.map_partitions(
self.enroll_samples super().enroll_samples
) )
return biometric_references return biometric_references
...@@ -102,5 +102,5 @@ class BioAlgDaskMixin: ...@@ -102,5 +102,5 @@ class BioAlgDaskMixin:
all_references = dask.delayed(list)(biometric_references) all_references = dask.delayed(list)(biometric_references)
scores = probe_features.map_partitions(self.score_samples, all_references) scores = probe_features.map_partitions(super().score_samples, all_references)
return scores return scores
...@@ -183,7 +183,7 @@ def vanilla_biometrics( ...@@ -183,7 +183,7 @@ def vanilla_biometrics(
logger.warning( logger.warning(
"`dask_client` not set. Your pipeline will run locally" "`dask_client` not set. Your pipeline will run locally"
) )
result = result.compute() result = result.compute(scheduler="single-threaded")
# Flatting out the list # Flatting out the list
result = itertools.chain(*result) result = itertools.chain(*result)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment