Skip to content
Snippets Groups Projects
Commit f4fcac7f authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira Committed by Amir MOHAMMADI
Browse files

[black]fy

parent 8ce11c18
No related branches found
No related tags found
1 merge request!180[dask] Preparing bob.bio.base for dask pipelines
...@@ -17,6 +17,7 @@ import logging ...@@ -17,6 +17,7 @@ import logging
import os import os
import bob.io.base import bob.io.base
import functools import functools
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -28,18 +29,21 @@ def scikit_to_bob_supervised(X, Y): ...@@ -28,18 +29,21 @@ def scikit_to_bob_supervised(X, Y):
""" """
# TODO: THIS IS VERY INNEFICI # TODO: THIS IS VERY INNEFICI
logger.warning("INEFFICIENCY WARNING. HERE YOU ARE USING A HACK FOR USING BOB ALGORITHMS IN SCIKIT LEARN PIPELINES. \ logger.warning(
WE RECOMMEND YOU TO PORT THIS ALGORITHM. DON'T BE LAZY :-)") "INEFFICIENCY WARNING. HERE YOU ARE USING A HACK FOR USING BOB ALGORITHMS IN SCIKIT LEARN PIPELINES. \
WE RECOMMEND YOU TO PORT THIS ALGORITHM. DON'T BE LAZY :-)"
)
bob_output = dict() bob_output = dict()
for x,y in zip(X, Y): for x, y in zip(X, Y):
if y in bob_output: if y in bob_output:
bob_output[y] = numpy.vstack((bob_output[y], x.data)) bob_output[y] = numpy.vstack((bob_output[y], x.data))
else: else:
bob_output[y] = x.data bob_output[y] = x.data
return [bob_output[k] for k in bob_output] return [bob_output[k] for k in bob_output]
class LegacyProcessorMixin(TransformerMixin): class LegacyProcessorMixin(TransformerMixin):
"""Class that wraps :py:class:`bob.bio.base.preprocessor.Preprocessor` and """Class that wraps :py:class:`bob.bio.base.preprocessor.Preprocessor` and
:py:class:`bob.bio.base.extractor.Extractors` :py:class:`bob.bio.base.extractor.Extractors`
...@@ -88,7 +92,9 @@ class LegacyProcessorMixin(TransformerMixin): ...@@ -88,7 +92,9 @@ class LegacyProcessorMixin(TransformerMixin):
from bob.pipelines.mixins import CheckpointMixin, SampleMixin from bob.pipelines.mixins import CheckpointMixin, SampleMixin
class LegacyAlgorithmMixin(CheckpointMixin,SampleMixin,BaseEstimator):
class LegacyAlgorithmMixin(CheckpointMixin, SampleMixin, BaseEstimator):
"""Class that wraps :py:class:`bob.bio.base.algorithm.Algoritm` """Class that wraps :py:class:`bob.bio.base.algorithm.Algoritm`
:py:method:`LegacyAlgorithmrMixin.fit` maps to :py:method:`bob.bio.base.algorithm.Algoritm.train_projector` :py:method:`LegacyAlgorithmrMixin.fit` maps to :py:method:`bob.bio.base.algorithm.Algoritm.train_projector`
...@@ -122,9 +128,8 @@ class LegacyAlgorithmMixin(CheckpointMixin,SampleMixin,BaseEstimator): ...@@ -122,9 +128,8 @@ class LegacyAlgorithmMixin(CheckpointMixin,SampleMixin,BaseEstimator):
self.instance = None self.instance = None
self.projector_file = None self.projector_file = None
def fit(self, X, y=None, **fit_params): def fit(self, X, y=None, **fit_params):
self.projector_file = os.path.join(self.model_path, "Projector.hdf5") self.projector_file = os.path.join(self.model_path, "Projector.hdf5")
if os.path.exists(self.projector_file): if os.path.exists(self.projector_file):
return self return self
...@@ -146,12 +151,11 @@ class LegacyAlgorithmMixin(CheckpointMixin,SampleMixin,BaseEstimator): ...@@ -146,12 +151,11 @@ class LegacyAlgorithmMixin(CheckpointMixin,SampleMixin,BaseEstimator):
return self return self
def transform(self, X): def transform(self, X):
def _project_save_sample(sample): def _project_save_sample(sample):
# Project # Project
projected_data = self.instance.project(sample.data) projected_data = self.instance.project(sample.data)
#Checkpointing # Checkpointing
path = self.make_path(sample) path = self.make_path(sample)
bob.io.base.create_directories_safe(os.path.dirname(path)) bob.io.base.create_directories_safe(os.path.dirname(path))
f = bob.io.base.HDF5File(path, "w") f = bob.io.base.HDF5File(path, "w")
...@@ -191,11 +195,11 @@ class LegacyAlgorithmMixin(CheckpointMixin,SampleMixin,BaseEstimator): ...@@ -191,11 +195,11 @@ class LegacyAlgorithmMixin(CheckpointMixin,SampleMixin,BaseEstimator):
def get_reader(reader, path): def get_reader(reader, path):
if(is_picklable(reader)): if is_picklable(reader):
return functools.partial(reader, path) return functools.partial(reader, path)
else: else:
logger.warning( logger.warning(
f"The method {reader} is not picklable. Shiping its unbounded method to `DelayedSample`." f"The method {reader} is not picklable. Shiping its unbounded method to `DelayedSample`."
) )
reader = reader.__func__ # The reader object might not be picklable reader = reader.__func__ # The reader object might not be picklable
return functools.partial(reader, None, path) return functools.partial(reader, None, path)
...@@ -21,6 +21,7 @@ class SamplePCA(SampleMixin, PCA): ...@@ -21,6 +21,7 @@ class SamplePCA(SampleMixin, PCA):
""" """
Enables SAMPLE handling for https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.PCA.html Enables SAMPLE handling for https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.PCA.html
""" """
pass pass
...@@ -28,4 +29,5 @@ class CheckpointSamplePCA(CheckpointMixin, SampleMixin, PCA): ...@@ -28,4 +29,5 @@ class CheckpointSamplePCA(CheckpointMixin, SampleMixin, PCA):
""" """
Enables SAMPLE and CHECKPOINTIN handling for https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.PCA.html Enables SAMPLE and CHECKPOINTIN handling for https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.PCA.html
""" """
pass pass
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment