Skip to content
Snippets Groups Projects
Commit 3e91ccfc authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira Committed by Amir MOHAMMADI
Browse files

Created Processors for linearized and PCA

parent bdb065f4
Branches
Tags
1 merge request!180[dask] Preparing bob.bio.base for dask pipelines
......@@ -4,6 +4,7 @@ from . import preprocessor
from . import extractor
from . import algorithm
from . import annotator
from . import processor
from . import script
from . import test
......
from .linearize import Linearize, SampleLinearize, CheckpointSampleLinearize
from .pca import CheckpointSamplePCA, SamplePCA
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
from bob.pipelines.processor import CheckpointMixin, SampleMixin
from sklearn.base import TransformerMixin
from sklearn.utils.validation import check_array, check_is_fitted
import numpy
class Linearize(TransformerMixin):
    """Stateless transformer that flattens array data into feature vectors.

    The input is first validated with ``check_array(X, allow_nd=True)``.
    A 2D array is then collapsed into one 1D vector containing all of its
    elements; an array with any other number of dimensions keeps its first
    axis (taken to index the samples) and has the remaining axes merged into
    a single feature axis.
    """

    def fit(self, X, y=None):
        """Return ``self`` unchanged; this transformer has nothing to learn."""
        return self

    def transform(self, X):
        """Linearize ``X``.

        Parameters
        ----------
        X : :py:class:`numpy.ndarray`
            The preprocessed data to be linearized.

        Returns
        -------
        :py:class:`numpy.ndarray`
            For a 2D ``X``: a 1D array of length ``X.size``.
            Otherwise: a 2D array of shape
            ``(X.shape[0], prod(X.shape[1:]))``.
        """
        X = check_array(X, allow_nd=True)

        if X.ndim != 2:
            # n-dimensional input: the first axis is assumed to enumerate the
            # samples, so only the trailing axes are merged together.
            n_samples = X.shape[0]
            n_features = numpy.prod(X.shape[1:])
            return numpy.reshape(X, (n_samples, n_features))

        # NOTE(review): a 2D input is treated as ONE sample and flattened
        # entirely (e.g. (10, 10) -> (100,)), not as (n_samples, n_features).
        return numpy.reshape(X, X.size)
class SampleLinearize(SampleMixin, Linearize):
    """:py:class:`Linearize` combined with ``SampleMixin``.

    Defines no behaviour of its own; everything is inherited from the bases.
    """
class CheckpointSampleLinearize(CheckpointMixin, SampleMixin, Linearize):
    """:py:class:`Linearize` combined with ``CheckpointMixin`` and ``SampleMixin``.

    Defines no behaviour of its own; everything is inherited from the bases.
    """
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
"""
TODO: This should be deployed in bob.pipelines
"""
from bob.pipelines.processor import CheckpointMixin, SampleMixin
from sklearn.base import TransformerMixin
from sklearn.decomposition import PCA
import numpy
"""
Wraps the scikit-learn PCA estimator so that it can handle samples and checkpointing
"""
class SamplePCA(SampleMixin, PCA):
    """
    scikit-learn PCA with SAMPLE handling added through ``SampleMixin``.

    Reference for the wrapped estimator:
    https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.PCA.html
    """
class CheckpointSamplePCA(CheckpointMixin, SampleMixin, PCA):
    """
    scikit-learn PCA with SAMPLE and CHECKPOINTING handling added through
    ``SampleMixin`` and ``CheckpointMixin``.

    Reference for the wrapped estimator:
    https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.PCA.html
    """
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
from bob.pipelines.sample import Sample, SampleSet, DelayedSample
import os
import numpy
import tempfile
from sklearn.utils.validation import check_is_fitted
from bob.bio.base.processor import Linearize, SampleLinearize, CheckpointSampleLinearize
def test_linearize_processor():
    """Exercise Linearize on its own, wrapped for Samples, and checkpointed."""
    flat_shape = (100,)
    X = numpy.zeros(shape=(10, 10))

    # 1) The bare transformer operating directly on an array
    plain = Linearize()
    assert plain.transform(X).shape == flat_shape

    # 2) The Sample-aware variant operating on a list of samples
    sample = Sample(X, key="1")
    wrapped = SampleLinearize()
    linearized = wrapped.transform([sample])
    assert linearized[0].data.shape == flat_shape

    # 3) The checkpointing variant: same result, plus a file named after the
    #    sample key inside the features directory
    with tempfile.TemporaryDirectory() as d:
        checkpointed = CheckpointSampleLinearize(features_dir=d)
        linearized = checkpointed.transform([sample])
        assert linearized[0].data.shape == flat_shape

        expected_file = os.path.join(d, "1.h5")
        assert os.path.exists(expected_file)
from bob.bio.base.processor import SamplePCA, CheckpointSamplePCA
def test_pca_processor():
    """Exercise SamplePCA and CheckpointSamplePCA: fit, transform, checkpoints."""
    n_components = 2
    expected_shape = (n_components,)

    # One Sample per row of a random 100x10 matrix, keyed by the row index
    X = numpy.random.rand(100, 10)
    samples = [Sample(row, key=str(idx)) for idx, row in enumerate(X)]

    # --- Sample-wrapped estimator ---
    estimator = SamplePCA(n_components=n_components).fit(samples)

    # check_is_fitted returns None on success and raises otherwise, see
    # https://scikit-learn.org/stable/modules/generated/sklearn.utils.validation.check_is_fitted.html
    assert check_is_fitted(estimator, "n_components_") is None

    projected = estimator.transform(samples)
    assert projected[0].data.shape == expected_shape

    # --- Checkpointing estimator ---
    with tempfile.TemporaryDirectory() as d:
        model_path = os.path.join(d, "model.pkl")
        estimator = CheckpointSamplePCA(
            n_components=n_components, features_dir=d, model_path=model_path
        )

        # Fitting must leave the model file behind
        estimator = estimator.fit(samples)
        assert check_is_fitted(estimator, "n_components_") is None
        assert os.path.exists(model_path)

        # Transforming must leave one feature file per sample key behind
        projected = estimator.transform(samples)
        first = projected[0]
        assert first.data.shape == expected_shape
        feature_file = os.path.join(d, first.key + ".h5")
        assert os.path.exists(feature_file)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment