From 144e75a424056b3e0236ad0e4aa74c3342d6140a Mon Sep 17 00:00:00 2001
From: Tiago Freitas Pereira <tiagofrepereira@gmail.com>
Date: Mon, 23 Mar 2020 19:44:00 +0100
Subject: [PATCH] Tes cases for some mixins

---
 bob/bio/base/test/test_mixins.py | 68 ++++++++++++++++++++++++++++++++
 1 file changed, 68 insertions(+)
 create mode 100644 bob/bio/base/test/test_mixins.py

diff --git a/bob/bio/base/test/test_mixins.py b/bob/bio/base/test/test_mixins.py
new file mode 100644
index 00000000..27701f9d
--- /dev/null
+++ b/bob/bio/base/test/test_mixins.py
@@ -0,0 +1,68 @@
+#!/usr/bin/env python
+# vim: set fileencoding=utf-8 :
+# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
+
+from bob.pipelines.sample import Sample, SampleSet, DelayedSample
+import os
+import numpy
+import tempfile
+from sklearn.utils.validation import check_is_fitted
+
+from bob.bio.base.mixins import Linearize, SampleLinearize, CheckpointSampleLinearize
+def test_linearize_processor():
+    
+    ## Test the transformer only
+    transformer = Linearize()
+    X = numpy.zeros(shape=(10,10))
+    X_tr = transformer.transform(X)
+    assert X_tr.shape == (100,)
+
+
+    ## Test wrapped in to a Sample
+    sample = Sample(X, key="1")
+    transformer = SampleLinearize()
+    X_tr = transformer.transform([sample])
+    assert X_tr[0].data.shape == (100,)
+
+    ## Test checkpoint    
+    with tempfile.TemporaryDirectory() as d:
+        transformer = CheckpointSampleLinearize(features_dir=d)
+        X_tr =  transformer.transform([sample])
+        assert X_tr[0].data.shape == (100,)
+        assert os.path.exists(os.path.join(d, "1.h5"))
+
+
+from bob.bio.base.mixins import SamplePCA, CheckpointSamplePCA
+def test_pca_processor():
+    
+    ## Test wrapped in to a Sample
+    X = numpy.random.rand(100,10)
+    samples = [Sample(data, key=str(i)) for i, data in enumerate(X)]
+
+    # fit
+    n_components = 2
+    estimator = SamplePCA(n_components=n_components)
+    estimator = estimator.fit(samples)
+    
+    # https://scikit-learn.org/stable/modules/generated/sklearn.utils.validation.check_is_fitted.html
+    assert check_is_fitted(estimator, "n_components_") is None
+    
+    # transform
+    samples_tr = estimator.transform(samples)
+    assert samples_tr[0].data.shape == (n_components,)
+    
+
+    ## Test Checkpoining
+    with tempfile.TemporaryDirectory() as d:        
+        model_path = os.path.join(d, "model.pkl")
+        estimator = CheckpointSamplePCA(n_components=n_components, features_dir=d, model_path=model_path)
+
+        # fit
+        estimator = estimator.fit(samples)
+        assert check_is_fitted(estimator, "n_components_") is None
+        assert os.path.exists(model_path)
+        
+        # transform
+        samples_tr = estimator.transform(samples)
+        assert samples_tr[0].data.shape == (n_components,)        
+        assert os.path.exists(os.path.join(d, samples_tr[0].key+".h5"))
-- 
GitLab