From 144e75a424056b3e0236ad0e4aa74c3342d6140a Mon Sep 17 00:00:00 2001 From: Tiago Freitas Pereira <tiagofrepereira@gmail.com> Date: Mon, 23 Mar 2020 19:44:00 +0100 Subject: [PATCH] Tes cases for some mixins --- bob/bio/base/test/test_mixins.py | 68 ++++++++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) create mode 100644 bob/bio/base/test/test_mixins.py diff --git a/bob/bio/base/test/test_mixins.py b/bob/bio/base/test/test_mixins.py new file mode 100644 index 00000000..27701f9d --- /dev/null +++ b/bob/bio/base/test/test_mixins.py @@ -0,0 +1,68 @@ +#!/usr/bin/env python +# vim: set fileencoding=utf-8 : +# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch> + +from bob.pipelines.sample import Sample, SampleSet, DelayedSample +import os +import numpy +import tempfile +from sklearn.utils.validation import check_is_fitted + +from bob.bio.base.mixins import Linearize, SampleLinearize, CheckpointSampleLinearize +def test_linearize_processor(): + + ## Test the transformer only + transformer = Linearize() + X = numpy.zeros(shape=(10,10)) + X_tr = transformer.transform(X) + assert X_tr.shape == (100,) + + + ## Test wrapped in to a Sample + sample = Sample(X, key="1") + transformer = SampleLinearize() + X_tr = transformer.transform([sample]) + assert X_tr[0].data.shape == (100,) + + ## Test checkpoint + with tempfile.TemporaryDirectory() as d: + transformer = CheckpointSampleLinearize(features_dir=d) + X_tr = transformer.transform([sample]) + assert X_tr[0].data.shape == (100,) + assert os.path.exists(os.path.join(d, "1.h5")) + + +from bob.bio.base.mixins import SamplePCA, CheckpointSamplePCA +def test_pca_processor(): + + ## Test wrapped in to a Sample + X = numpy.random.rand(100,10) + samples = [Sample(data, key=str(i)) for i, data in enumerate(X)] + + # fit + n_components = 2 + estimator = SamplePCA(n_components=n_components) + estimator = estimator.fit(samples) + + # https://scikit-learn.org/stable/modules/generated/sklearn.utils.validation.check_is_fitted.html + assert check_is_fitted(estimator, "n_components_") is None + + # transform + samples_tr = estimator.transform(samples) + assert samples_tr[0].data.shape == (n_components,) + + + ## Test Checkpoining + with tempfile.TemporaryDirectory() as d: + model_path = os.path.join(d, "model.pkl") + estimator = CheckpointSamplePCA(n_components=n_components, features_dir=d, model_path=model_path) + + # fit + estimator = estimator.fit(samples) + assert check_is_fitted(estimator, "n_components_") is None + assert os.path.exists(model_path) + + # transform + samples_tr = estimator.transform(samples) + assert samples_tr[0].data.shape == (n_components,) + assert os.path.exists(os.path.join(d, samples_tr[0].key+".h5")) -- GitLab