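"""Tests for ``bob.pipelines.wrap``: adapting scikit-learn estimators to work on
Sample objects ("sample"), to checkpoint fitted models and features to disk
("checkpoint"), and to run lazily on dask bags ("dask")."""
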
import numpy as np
import os
import tempfile
import shutil

from sklearn.base import TransformerMixin, BaseEstimator
from sklearn.utils.estimator_checks import check_estimator
from sklearn.utils.validation import check_array, check_is_fitted
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import FunctionTransformer

import bob.pipelines as mario


def _offset_add_func(X, offset=1):
    return X + offset


class DummyWithFit(TransformerMixin, BaseEstimator):
    """See https://scikit-learn.org/stable/developers/develop.html and
    https://github.com/scikit-learn-contrib/project-template/blob/master/skltemplate/_template.py."""

    def fit(self, X, y=None):
        X = check_array(X)
        self.n_features_ = X.shape[1]

        self.model_ = np.ones((self.n_features_, 2))

        # Return the transformer
        return self

    def transform(self, X):
        # Check that fit has been called
        check_is_fitted(self, "n_features_")
        # Input validation
        X = check_array(X)
        # Check that the input is of the same shape as the one passed
        # during fit.
        if X.shape[1] != self.n_features_:
            raise ValueError(
                "Shape of input is different from what was seen in `fit`"
            )
        return X @ self.model_


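# A stateless transformer (see ``_more_tags``) that adds a fixed offset to its
# input; with ``picklable=False`` it holds a bob.core RNG object, which is used
# below to exercise dask execution with non-picklable estimators.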
class DummyTransformer(TransformerMixin, BaseEstimator):
    """See https://scikit-learn.org/stable/developers/develop.html and
    https://github.com/scikit-learn-contrib/project-template/blob/master/skltemplate/_template.py."""

    def __init__(self, picklable=True, i=None, **kwargs):
        super().__init__(**kwargs)
        self.picklable = picklable
        self.i = i

        if not picklable:
            import bob.core

            self.rng = bob.core.random.mt19937()

    def fit(self, X, y=None):
        return self

    def transform(self, X):

        # Input validation
        X = check_array(X)
        return _offset_add_func(X)

    def _more_tags(self):
        return {"stateless": True, "requires_fit": False}


def _assert_all_close_numpy_array(oracle, result):
    oracle, result = np.array(oracle), np.array(result)
    assert (
        oracle.shape == result.shape
    ), f"Expected: {oracle.shape} but got: {result.shape}"
    assert np.allclose(oracle, result), f"Expected: {oracle} but got: {result}"


def test_sklearn_compatible_estimator():
    # check the estimator for scikit-learn API consistency
    check_estimator(DummyWithFit())


def test_function_sample_transformer():

    X = np.zeros(shape=(10, 2), dtype=int)
    samples = [mario.Sample(data) for data in X]

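    # the "sample" wrapper feeds sample.data to the wrapped estimator and wraps
    # the results back into samples (checked below through s.data)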
    transformer = mario.wrap(
        [FunctionTransformer, "sample"],
        func=_offset_add_func,
        kw_args=dict(offset=3),
        validate=True,
    )

    features = transformer.transform(samples)
    _assert_all_close_numpy_array(X + 3, [s.data for s in features])

    features = transformer.fit_transform(samples)
    _assert_all_close_numpy_array(X + 3, [s.data for s in features])


def test_fittable_sample_transformer():

    X = np.ones(shape=(10, 2), dtype=int)
    samples = [mario.Sample(data) for data in X]

    # Wrap an estimator object that requires fitting
    transformer = mario.wrap([DummyWithFit, "sample"])
    features = transformer.fit(samples).transform(samples)
    _assert_all_close_numpy_array(X + 1, [s.data for s in features])

    features = transformer.fit_transform(samples)
    _assert_all_close_numpy_array(X + 1, [s.data for s in features])


def _assert_checkpoints(features, oracle, model_path, features_dir, stateless):
    _assert_all_close_numpy_array(oracle, [s.data for s in features])
    if stateless:
        assert not os.path.exists(model_path)
    else:
        assert os.path.exists(model_path), os.listdir(os.path.dirname(model_path))
    assert os.path.isdir(features_dir)
    for i in range(len(oracle)):
        assert os.path.isfile(os.path.join(features_dir, f"{i}.h5"))


def _assert_delayed_samples(samples):
    for s in samples:
        assert isinstance(s, mario.DelayedSample)


def test_checkpoint_function_sample_transformer():

    X = np.arange(20, dtype=int).reshape(10, 2)
    samples = [mario.Sample(data, key=str(i)) for i, data in enumerate(X)]
    offset = 3
    oracle = X + offset

    with tempfile.TemporaryDirectory() as d:
        model_path = os.path.join(d, "model.pkl")
        features_dir = os.path.join(d, "features")

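        # the "checkpoint" wrapper additionally writes the fitted model to
        # model_path (skipped for stateless estimators) and each transformed
        # sample to <features_dir>/<key>.h5 (see _assert_checkpoints)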
        transformer = mario.wrap(
            [FunctionTransformer, "sample", "checkpoint"],
            func=_offset_add_func,
            kw_args=dict(offset=offset),
            validate=True,
            model_path=model_path,
            features_dir=features_dir,
        )

        features = transformer.transform(samples)
        _assert_checkpoints(features, oracle, model_path, features_dir, True)

        features = transformer.fit_transform(samples)
        _assert_checkpoints(features, oracle, model_path, features_dir, True)
        _assert_delayed_samples(features)

        # remove all files and call fit_transform again
        shutil.rmtree(d)
        features = transformer.fit_transform(samples)
        _assert_checkpoints(features, oracle, model_path, features_dir, True)

    # test when both model_path and features_dir are None
    transformer = mario.wrap(
        [FunctionTransformer, "sample", "checkpoint"],
        func=_offset_add_func,
        kw_args=dict(offset=offset),
        validate=True,
    )
    features = transformer.transform(samples)
    _assert_all_close_numpy_array(oracle, [s.data for s in features])


def test_checkpoint_fittable_sample_transformer():
    X = np.ones(shape=(10, 2), dtype=int)
    samples = [mario.Sample(data, key=str(i)) for i, data in enumerate(X)]
    oracle = X + 1

    with tempfile.TemporaryDirectory() as d:
        model_path = os.path.join(d, "model.pkl")
        features_dir = os.path.join(d, "features")

        transformer = mario.wrap(
            [DummyWithFit, "sample", "checkpoint"],
            model_path=model_path,
            features_dir=features_dir,
        )
        assert not mario.utils.is_estimator_stateless(transformer)
        features = transformer.fit(samples).transform(samples)
        _assert_checkpoints(features, oracle, model_path, features_dir, False)

        features = transformer.fit_transform(samples)
        _assert_checkpoints(features, oracle, model_path, features_dir, False)
        _assert_delayed_samples(features)

        # remove all files and call fit_transform again
        shutil.rmtree(d)
        features = transformer.fit_transform(samples)
        _assert_checkpoints(features, oracle, model_path, features_dir, False)


def _build_estimator(path, i):
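    """Builds a checkpointed DummyWithFit rooted at ``<path>/transformer<i>``."""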
    base_dir = os.path.join(path, f"transformer{i}")
    os.makedirs(base_dir, exist_ok=True)
    model_path = os.path.join(base_dir, "model.pkl")
    features_dir = os.path.join(base_dir, "features")

    transformer = mario.wrap(
        [DummyWithFit, "sample", "checkpoint"],
        model_path=model_path,
        features_dir=features_dir,
    )
    return transformer


def _build_transformer(path, i, picklable=True):
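    """Builds a checkpointed DummyTransformer rooted at ``<path>/transformer<i>``."""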

    features_dir = os.path.join(path, f"transformer{i}")
    estimator = mario.wrap(
        [DummyTransformer, "sample", "checkpoint"], i=i, features_dir=features_dir
    )
    return estimator


def test_checkpoint_fittable_pipeline():

    X = np.ones(shape=(10, 2), dtype=int)
    samples = [mario.Sample(data, key=str(i)) for i, data in enumerate(X)]
    samples_transform = [
        mario.Sample(data, key=str(i + 10)) for i, data in enumerate(X)
    ]
    oracle = X + 3

    with tempfile.TemporaryDirectory() as d:
        pipeline = Pipeline([(f"{i}", _build_estimator(d, i)) for i in range(2)])
        pipeline.fit(samples)

        transformed_samples = pipeline.transform(samples_transform)

        _assert_all_close_numpy_array(oracle, [s.data for s in transformed_samples])


def test_checkpoint_transform_pipeline():
    def _run(dask_enabled):

        X = np.ones(shape=(10, 2), dtype=int)
        samples_transform = [mario.Sample(data, key=str(i)) for i, data in enumerate(X)]
        offset = 2
        oracle = X + offset

        with tempfile.TemporaryDirectory() as d:
            pipeline = Pipeline(
                [(f"{i}", _build_transformer(d, i)) for i in range(offset)]
            )
            if dask_enabled:
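                # the "dask" wrapper makes the pipeline operate on dask bags;
                # nothing runs until .compute() is called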
                pipeline = mario.wrap(["dask"], pipeline)
                transformed_samples = pipeline.transform(samples_transform).compute(
                    scheduler="single-threaded"
                )
            else:
                transformed_samples = pipeline.transform(samples_transform)

            _assert_all_close_numpy_array(oracle, [s.data for s in transformed_samples])

    _run(dask_enabled=True)
    _run(dask_enabled=False)


def test_checkpoint_fit_transform_pipeline():
    def _run(dask_enabled):
        X = np.ones(shape=(10, 2), dtype=int)
        samples = [mario.Sample(data, key=str(i)) for i, data in enumerate(X)]
        samples_transform = [
            mario.Sample(data, key=str(i + 10)) for i, data in enumerate(X)
        ]
        oracle = X + 2

        with tempfile.TemporaryDirectory() as d:
            fitter = ("0", _build_estimator(d, 0))
            transformer = ("1", _build_transformer(d, 1))
            pipeline = Pipeline([fitter, transformer])
            if dask_enabled:
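                # fit_tag attaches a resource tag to the fit step; it is read
                # back below through mario.dask_tags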
                pipeline = mario.wrap(
                    ["dask"], pipeline, fit_tag=[(1, "GPU")], npartitions=1
                )
                pipeline = pipeline.fit(samples)
                tags = mario.dask_tags(pipeline)

                assert len(tags) == 1, tags
                transformed_samples = pipeline.transform(samples_transform)

                transformed_samples = transformed_samples.compute(
                    scheduler="single-threaded"
                )
            else:
                pipeline = pipeline.fit(samples)
                transformed_samples = pipeline.transform(samples_transform)

            _assert_all_close_numpy_array(oracle, [s.data for s in transformed_samples])

    _run(dask_enabled=True)
    _run(dask_enabled=False)


def _get_local_client():
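    """Returns a single-worker, in-process dask client (processes=False), so
    that non-picklable estimators never need to be serialized between
    workers."""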
    from dask.distributed import Client, LocalCluster

    cluster = LocalCluster(
        nanny=False, processes=False, n_workers=1, threads_per_worker=1
    )
    cluster.scale_up(1)
    return Client(cluster)  # start local workers as threads


def test_checkpoint_fit_transform_pipeline_with_dask_non_pickle():
    def _run(dask_enabled):
        X = np.ones(shape=(10, 2), dtype=int)
        samples = [mario.Sample(data, key=str(i)) for i, data in enumerate(X)]
        samples_transform = [
            mario.Sample(data, key=str(i + 10)) for i, data in enumerate(X)
        ]
        oracle = X + 2

        with tempfile.TemporaryDirectory() as d:
            fitter = ("0", _build_estimator(d, 0))
            transformer = (
                "1",
                _build_transformer(d, 1, picklable=False),
            )

            pipeline = Pipeline([fitter, transformer])
            if dask_enabled:
                dask_client = _get_local_client()
                pipeline = mario.wrap(["dask"], pipeline)
                pipeline = pipeline.fit(samples)
                transformed_samples = pipeline.transform(samples_transform).compute(
                    scheduler=dask_client
                )
            else:
                pipeline = pipeline.fit(samples)
                transformed_samples = pipeline.transform(samples_transform)

            _assert_all_close_numpy_array(oracle, [s.data for s in transformed_samples])

    _run(True)
    _run(False)


def test_dask_checkpoint_transform_pipeline():
    X = np.ones(shape=(10, 2), dtype=int)
    samples_transform = [mario.Sample(data, key=str(i)) for i, data in enumerate(X)]
    with tempfile.TemporaryDirectory() as d:
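        # ToDaskBag turns the list of samples into a dask bag, which the
        # dask-wrapped estimator then consumes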
        bag_transformer = mario.ToDaskBag()
        estimator = mario.wrap(["dask"], _build_transformer(d, 0), transform_tag="CPU")
        X_tr = estimator.transform(bag_transformer.transform(samples_transform))
        assert len(mario.dask_tags(estimator)) == 1
        assert len(X_tr.compute(scheduler="single-threaded")) == 10


def test_checkpoint_transform_pipeline_with_sampleset():
    def _run(dask_enabled):

        X = np.ones(shape=(10, 2), dtype=int)
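        # a SampleSet groups samples under a common key; the wrappers transform
        # every member sample (each transformed set still holds 10 samples)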
        samples_transform = mario.SampleSet(
            [mario.Sample(data, key=str(i)) for i, data in enumerate(X)], key="1"
        )
        offset = 2
        oracle = X + offset

        with tempfile.TemporaryDirectory() as d:
            pipeline = Pipeline(
                [(f"{i}", _build_transformer(d, i)) for i in range(offset)]
            )
            if dask_enabled:
                pipeline = mario.wrap(["dask"], pipeline)
                transformed_samples = pipeline.transform([samples_transform]).compute(
                    scheduler="single-threaded"
                )
            else:
                transformed_samples = pipeline.transform([samples_transform])

            _assert_all_close_numpy_array(
                oracle,
                [s.data for sample_set in transformed_samples for s in sample_set],
            )
            assert np.all([len(s) == 10 for s in transformed_samples])

    _run(dask_enabled=True)
    _run(dask_enabled=False)