Skip to content
Snippets Groups Projects
Commit f4fcac7f authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira Committed by Amir MOHAMMADI
Browse files

[black]fy

parent 8ce11c18
Branches
Tags
1 merge request!180[dask] Preparing bob.bio.base for dask pipelines
......@@ -17,6 +17,7 @@ import logging
import os
import bob.io.base
import functools
logger = logging.getLogger(__name__)
......@@ -28,18 +29,21 @@ def scikit_to_bob_supervised(X, Y):
"""
# TODO: THIS IS VERY INNEFICI
logger.warning("INEFFICIENCY WARNING. HERE YOU ARE USING A HACK FOR USING BOB ALGORITHMS IN SCIKIT LEARN PIPELINES. \
WE RECOMMEND YOU TO PORT THIS ALGORITHM. DON'T BE LAZY :-)")
logger.warning(
"INEFFICIENCY WARNING. HERE YOU ARE USING A HACK FOR USING BOB ALGORITHMS IN SCIKIT LEARN PIPELINES. \
WE RECOMMEND YOU TO PORT THIS ALGORITHM. DON'T BE LAZY :-)"
)
bob_output = dict()
for x,y in zip(X, Y):
for x, y in zip(X, Y):
if y in bob_output:
bob_output[y] = numpy.vstack((bob_output[y], x.data))
else:
bob_output[y] = x.data
return [bob_output[k] for k in bob_output]
class LegacyProcessorMixin(TransformerMixin):
"""Class that wraps :py:class:`bob.bio.base.preprocessor.Preprocessor` and
:py:class:`bob.bio.base.extractor.Extractors`
......@@ -88,7 +92,9 @@ class LegacyProcessorMixin(TransformerMixin):
from bob.pipelines.mixins import CheckpointMixin, SampleMixin
class LegacyAlgorithmMixin(CheckpointMixin,SampleMixin,BaseEstimator):
class LegacyAlgorithmMixin(CheckpointMixin, SampleMixin, BaseEstimator):
"""Class that wraps :py:class:`bob.bio.base.algorithm.Algoritm`
:py:method:`LegacyAlgorithmrMixin.fit` maps to :py:method:`bob.bio.base.algorithm.Algoritm.train_projector`
......@@ -122,9 +128,8 @@ class LegacyAlgorithmMixin(CheckpointMixin,SampleMixin,BaseEstimator):
self.instance = None
self.projector_file = None
def fit(self, X, y=None, **fit_params):
self.projector_file = os.path.join(self.model_path, "Projector.hdf5")
if os.path.exists(self.projector_file):
return self
......@@ -146,12 +151,11 @@ class LegacyAlgorithmMixin(CheckpointMixin,SampleMixin,BaseEstimator):
return self
def transform(self, X):
def _project_save_sample(sample):
# Project
projected_data = self.instance.project(sample.data)
#Checkpointing
# Checkpointing
path = self.make_path(sample)
bob.io.base.create_directories_safe(os.path.dirname(path))
f = bob.io.base.HDF5File(path, "w")
......@@ -191,11 +195,11 @@ class LegacyAlgorithmMixin(CheckpointMixin,SampleMixin,BaseEstimator):
def get_reader(reader, path):
if(is_picklable(reader)):
if is_picklable(reader):
return functools.partial(reader, path)
else:
logger.warning(
f"The method {reader} is not picklable. Shiping its unbounded method to `DelayedSample`."
)
f"The method {reader} is not picklable. Shiping its unbounded method to `DelayedSample`."
)
reader = reader.__func__ # The reader object might not be picklable
return functools.partial(reader, None, path)
......@@ -21,6 +21,7 @@ class SamplePCA(SampleMixin, PCA):
"""
Enables SAMPLE handling for https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.PCA.html
"""
pass
......@@ -28,4 +29,5 @@ class CheckpointSamplePCA(CheckpointMixin, SampleMixin, PCA):
"""
Enables SAMPLE and CHECKPOINTIN handling for https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.PCA.html
"""
pass
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment