Commit 611f8867 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

Remove unused transformers, rename is_instance_nested, fix API docs

parent ecc8c7b4
Pipeline #61358 passed with stage
in 12 minutes and 25 seconds
......@@ -29,8 +29,8 @@ from .wrappers import ( # noqa: F401
dask_tags,
estimator_requires_fit,
get_bob_tags,
is_instance_nested,
is_pipeline_wrapped,
isinstance_nested,
)
......
import os
import tempfile
import numpy as np
from sklearn.utils.validation import check_is_fitted
import bob.pipelines as mario
def test_linearize():
    """Check Linearize and its sample/checkpoint wrappers flatten data."""

    def check(actual, expected):
        assert np.allclose(actual, expected), (actual, expected)

    data = np.zeros(shape=(10, 10, 10))
    expected = data.reshape((10, -1))

    # Bare transformer on a raw numpy array
    out = mario.transformers.Linearize().transform(data)
    check(out, expected)

    # Sample-wrapped transformer: one Sample per data point
    samples = [mario.Sample(x, key=f"{i}") for i, x in enumerate(data)]
    out = mario.transformers.SampleLinearize().transform(samples)
    check([s.data for s in out], expected)

    # Checkpointing variant: results also land on disk, one file per key
    with tempfile.TemporaryDirectory() as tmp_dir:
        checkpointer = mario.transformers.CheckpointSampleLinearize(
            features_dir=tmp_dir
        )
        out = checkpointer.transform(samples)
        check([s.data for s in out], expected)
        assert os.path.exists(os.path.join(tmp_dir, "1.h5"))
def test_pca():
    """Check SamplePCA and its checkpointing variant fit and transform."""
    # Random 10-dimensional points wrapped into Samples
    X = np.random.rand(100, 10)
    samples = [mario.Sample(data, key=str(i)) for i, data in enumerate(X)]
    n_components = 2

    # Sample-wrapped PCA: fit, then transform
    estimator = mario.transformers.SamplePCA(n_components=n_components)
    estimator = estimator.fit(samples)
    # check_is_fitted returns None when the attribute exists, raises otherwise
    # https://scikit-learn.org/stable/modules/generated/sklearn.utils.validation.check_is_fitted.html
    assert check_is_fitted(estimator, "n_components_") is None
    transformed = estimator.transform(samples)
    assert transformed[0].data.shape == (n_components,)

    # Checkpointing variant: fitted model and features are persisted on disk
    with tempfile.TemporaryDirectory() as tmp_dir:
        model_path = os.path.join(tmp_dir, "model.pkl")
        estimator = mario.transformers.CheckpointSamplePCA(
            n_components=n_components, features_dir=tmp_dir, model_path=model_path
        )
        estimator = estimator.fit(samples)
        assert check_is_fitted(estimator, "n_components_") is None
        assert os.path.exists(model_path)
        transformed = estimator.transform(samples)
        assert transformed[0].data.shape == (n_components,)
        assert os.path.exists(os.path.join(tmp_dir, transformed[0].key + ".h5"))
def test_str_to_types():
    """Check Str_To_Types converts string attributes to typed values in place."""
    samples = [
        mario.Sample(None, id="1", flag="True"),
        mario.Sample(None, id="2", flag="False"),
    ]
    # id: "1" -> 1 via int; flag: "True"/"False" -> bool via str_to_bool
    converter = mario.transformers.Str_To_Types(
        fieldtypes=dict(id=int, flag=mario.transformers.str_to_bool)
    )
    converter.transform(samples)

    expectations = [(1, True), (2, False)]
    for sample, (expected_id, expected_flag) in zip(samples, expectations):
        assert sample.id == expected_id
        assert sample.flag is expected_flag
......@@ -74,7 +74,7 @@ def test_is_pipeline_wrapped():
)
def test_isinstance_nested():
def test_is_instance_nested():
class A:
pass
......@@ -87,14 +87,14 @@ def test_isinstance_nested():
self.o = o
o = C(B(A()))
assert mario.isinstance_nested(o, "o", C)
assert mario.isinstance_nested(o, "o", B)
assert mario.isinstance_nested(o, "o", A)
assert mario.is_instance_nested(o, "o", C)
assert mario.is_instance_nested(o, "o", B)
assert mario.is_instance_nested(o, "o", A)
o = C(B(object))
assert mario.isinstance_nested(o, "o", C)
assert mario.isinstance_nested(o, "o", B)
assert not mario.isinstance_nested(o, "o", A)
assert mario.is_instance_nested(o, "o", C)
assert mario.is_instance_nested(o, "o", B)
assert not mario.is_instance_nested(o, "o", A)
def test_break_sample_set():
......
from .file_loader import FileLoader
from .linearize import CheckpointSampleLinearize, Linearize, SampleLinearize
from .pca import CheckpointSamplePCA, SamplePCA
from .str_to_types import Str_To_Types # noqa: F401
from .str_to_types import str_to_bool # noqa: F401
......@@ -23,14 +20,7 @@ def __appropriate__(*args):
obj.__module__ = __name__
__appropriate__(
Linearize,
SampleLinearize,
CheckpointSampleLinearize,
CheckpointSamplePCA,
SamplePCA,
FileLoader,
)
__appropriate__()
# gets sphinx autodoc done right - don't remove it
__all__ = [_ for _ in dir() if not _.startswith("_")]
import os
from sklearn.preprocessing import FunctionTransformer
from bob.io.base import load
from ..wrappers import wrap
def file_loader(files, original_directory, original_extension):
    """Load every file in *files* relative to a base directory.

    Parameters
    ----------
    files : iterable of str
        Relative paths, without extension.
    original_directory : str
        Base directory prepended to each path.
    original_extension : str
        Extension appended to each path.

    Returns
    -------
    list
        Loaded data objects, one per input path, in order.
    """
    return [
        load(os.path.join(original_directory, path + original_extension))
        for path in files
    ]
def FileLoader(original_directory, original_extension=None, **kwargs):
    """Build a stateless transformer that loads files from a directory.

    Parameters
    ----------
    original_directory : str or None
        Base directory for the files; falsy values are treated as "".
    original_extension : str or None
        Extension appended to each file name; falsy values are treated as "".

    Returns
    -------
    sklearn.preprocessing.FunctionTransformer
        Transformer whose ``transform`` maps a list of paths to loaded data.
    """
    # NOTE(review): **kwargs is accepted but not forwarded to
    # FunctionTransformer — confirm whether that is intentional.
    kw_args = dict(
        original_directory=original_directory or "",
        original_extension=original_extension or "",
    )
    return FunctionTransformer(file_loader, validate=False, kw_args=kw_args)
def key_based_file_loader(original_directory, original_extension):
    """Create a sample-aware FileLoader driven by ``sample.key``.

    The returned transformer reads the path from ``sample.key`` and stores
    the loaded content in ``sample.data``.
    """
    loader = FileLoader(original_directory, original_extension)
    return wrap(["sample"], loader, input_attribute="key")
import numpy as np
from sklearn.preprocessing import FunctionTransformer
from ..wrappers import wrap
def linearize(X):
    """Flatten each sample: collapse all axes after the first into one.

    Parameters
    ----------
    X : array-like
        Batch of samples; the first axis indexes samples.

    Returns
    -------
    numpy.ndarray
        2-D array of shape ``(len(X), product_of_remaining_dims)``.
    """
    arr = np.asarray(X)
    return arr.reshape((arr.shape[0], -1))
class Linearize(FunctionTransformer):
    """Feature extractor that flattens every data point into a single
    long one-dimensional vector (all axes after the first are merged)."""

    def __init__(self, **kwargs):
        # Delegate to FunctionTransformer with the module-level
        # ``linearize`` function as the stateless transform.
        super().__init__(func=linearize, **kwargs)
def SampleLinearize(**kwargs):
    """Return a :any:`Linearize` transformer wrapped for Sample handling."""
    return wrap([Linearize, "sample"], **kwargs)
def CheckpointSampleLinearize(**kwargs):
    """Return a :any:`Linearize` transformer wrapped for Sample handling
    and on-disk checkpointing of its outputs."""
    return wrap([Linearize, "sample", "checkpoint"], **kwargs)
from sklearn.decomposition import PCA
from ..wrappers import wrap
def SamplePCA(**kwargs):
    """Enables SAMPLE handling for :any:`sklearn.decomposition.PCA`"""
    return wrap([PCA, "sample"], **kwargs)
def CheckpointSamplePCA(**kwargs):
    """Enables SAMPLE and CHECKPOINTING handling for
    :any:`sklearn.decomposition.PCA`"""
    return wrap([PCA, "sample", "checkpoint"], **kwargs)
......@@ -58,7 +58,7 @@ def flatten_samplesets(samplesets):
Parameters
----------
samplesets: list of SampleSets
samplesets: list of :obj:`bob.pipelines.SampleSet`
Input list of SampleSets (with one or multiple samples in each SampleSet
"""
......
......@@ -648,7 +648,7 @@ def _update_estimator(estimator, loaded_estimator):
def is_checkpointed(estimator):
    """Tell whether the estimator is checkpoint-wrapped.

    Searches ``estimator`` and its nested ``estimator`` attributes for a
    :any:`CheckpointWrapper`.

    Parameters
    ----------
    estimator : object
        A scikit-learn estimator, possibly wrapped.

    Returns
    -------
    bool
        True when a :any:`CheckpointWrapper` is found in the nesting chain.
    """
    # The old spelling ``isinstance_nested`` was renamed; only the
    # ``is_instance_nested`` call is kept (the stale duplicate line from
    # the diff would have called an undefined name).
    return is_instance_nested(estimator, "estimator", CheckpointWrapper)
def getattr_nested(estimator, attr):
......@@ -1055,7 +1055,7 @@ def estimator_requires_fit(estimator):
)
# If the estimator is wrapped, check the wrapped estimator
if isinstance_nested(
if is_instance_nested(
estimator, "estimator", (SampleWrapper, CheckpointWrapper, DaskWrapper)
):
return estimator_requires_fit(estimator.estimator)
......@@ -1066,7 +1066,7 @@ def estimator_requires_fit(estimator):
# We check for the FunctionTransformer since theoretically it
# does require fit but it does not really need it.
if isinstance_nested(estimator, "estimator", FunctionTransformer):
if is_instance_nested(estimator, "estimator", FunctionTransformer):
return False
# if the estimator does not require fit, don't call fit
......@@ -1075,7 +1075,7 @@ def estimator_requires_fit(estimator):
return tags["requires_fit"]
def isinstance_nested(instance, attribute, isinstance_of):
def is_instance_nested(instance, attribute, isinstance_of):
"""
Check if an object and its nested objects is an instance of a class.
......@@ -1107,7 +1107,7 @@ def isinstance_nested(instance, attribute, isinstance_of):
return True
else:
# Recursive search
return isinstance_nested(
return is_instance_nested(
getattr(instance, attribute), attribute, isinstance_of
)
......@@ -1123,19 +1123,19 @@ def is_pipeline_wrapped(estimator, wrapper):
estimator: sklearn.pipeline.Pipeline
Pipeline to be checked
wrapper: class
Wrapper to be checked
wrapper: type
The Wrapper class or a tuple of classes to be checked
Returns
-------
list
Returns a list of boolean values, where each value indicates if the corresponding estimator is wrapped or not
"""
if not isinstance(estimator, Pipeline):
raise ValueError(f"{estimator} is not an instance of Pipeline")
return [
isinstance_nested(trans, "estimator", wrapper)
is_instance_nested(trans, "estimator", wrapper)
for _, _, trans in estimator._iter()
]
......@@ -29,7 +29,7 @@ def save(data, path):
def load(path):
    """Read the whole "array" dataset from the HDF5 file at ``path``.

    Parameters
    ----------
    path : str
        Path to an HDF5 file containing an "array" dataset.

    Returns
    -------
    numpy.ndarray (or scalar for 0-d datasets)
        The dataset contents, read in one shot via ``[()]``.
    """
    with h5py.File(path, "r") as f:
        # ``dataset[()]`` reads the full dataset directly; the stale
        # ``np.array(f["array"])`` line from the diff was a dead extra copy.
        data = f["array"][()]
    return data
......
......@@ -3,27 +3,81 @@
Python API for bob.pipelines
============================
Main module
-----------
Summary
=======
.. automodule:: bob.pipelines
Sample's API
------------
.. autosummary::
bob.pipelines.Sample
bob.pipelines.DelayedSample
bob.pipelines.SampleSet
bob.pipelines.DelayedSampleSet
bob.pipelines.DelayedSampleSetCached
bob.pipelines.SampleBatch
Heterogeneous SGE
Wrapper's API
-------------
.. autosummary::
bob.pipelines.wrap
bob.pipelines.BaseWrapper
bob.pipelines.SampleWrapper
bob.pipelines.CheckpointWrapper
bob.pipelines.DaskWrapper
bob.pipelines.ToDaskBag
bob.pipelines.DelayedSamplesCall
Database's API
--------------
.. autosummary::
bob.pipelines.datasets.FileListDatabase
bob.pipelines.datasets.FileListToSamples
bob.pipelines.datasets.CSVToSamples
Transformers' API
-----------------
.. autosummary::
bob.pipelines.transformers.Str_To_Types
bob.pipelines.transformers.str_to_bool
Xarray's API
------------
.. autosummary::
bob.pipelines.xarray.samples_to_dataset
bob.pipelines.xarray.DatasetPipeline
bob.pipelines.xarray.Block
Utilities
---------
.. autosummary::
bob.pipelines.assert_picklable
bob.pipelines.check_parameter_for_validity
bob.pipelines.check_parameters_for_validity
bob.pipelines.dask_tags
bob.pipelines.estimator_requires_fit
bob.pipelines.flatten_samplesets
bob.pipelines.get_bob_tags
bob.pipelines.hash_string
bob.pipelines.is_instance_nested
bob.pipelines.is_picklable
bob.pipelines.is_pipeline_wrapped
Main module
===========
.. automodule:: bob.pipelines
Heterogeneous SGE
=================
.. automodule:: bob.pipelines.distributed.sge
Transformers
------------
============
.. automodule:: bob.pipelines.transformers
xarray Wrapper
--------------
==============
.. automodule:: bob.pipelines.xarray
Filelist Datasets
-----------------
=================
.. automodule:: bob.pipelines.datasets
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment