diff --git a/bob/pipelines/wrappers.py b/bob/pipelines/wrappers.py index 885acfa95be99edabc1db5e81c8ba6f3e54a7188..0297c78c97057ef3f0e39d2e34a100920b08919a 100644 --- a/bob/pipelines/wrappers.py +++ b/bob/pipelines/wrappers.py @@ -153,6 +153,12 @@ class SampleWrapper(BaseWrapper, TransformerMixin): return self._samples_transform(samples, "score") def fit(self, samples, y=None): + if y is not None: + raise TypeError( + "We don't accept `y` in fit arguments because " + "`y` should be part of the sample. To pass `y` " + "to the wrapped estimator, use `fit_extra_arguments`." + ) if is_estimator_stateless(self.estimator): return self @@ -163,7 +169,7 @@ class SampleWrapper(BaseWrapper, TransformerMixin): X = SampleBatch(samples) - self.estimator = self.estimator.fit(X, y=y, **kwargs) + self.estimator = self.estimator.fit(X, **kwargs) copy_learned_attributes(self.estimator, self) return self diff --git a/doc/checkpoint.rst b/doc/checkpoint.rst index 0646c7b3068e9cac772b8a50a6d3a9d8c59b5466..2f3ed31527dea1cf1426bcfdffbf94afbe26e0f8 100644 --- a/doc/checkpoint.rst +++ b/doc/checkpoint.rst @@ -40,7 +40,7 @@ transformer. ... print(f"Transforming {len(X)} samples ...") ... return np.array(X) + np.array(sample_specific_offsets) ... - ... def fit(self, X, y=None): + ... def fit(self, X): ... print("Fit was called!") ... return self diff --git a/doc/xarray.rst b/doc/xarray.rst index c6fd4bd0fbc5f9c0341fa2d9943d2a6cd53caf36..53cd9613f1fb3e267ec0f33cdd8077d8cd76a3b4 100644 --- a/doc/xarray.rst +++ b/doc/xarray.rst @@ -445,7 +445,7 @@ provide dask-ml estimators, set ``input_dask_array`` as ``True``. >>> ds = pipeline.fit(dataset).predict(dataset) >>> ds = ds.compute() >>> correct_classification = np.array(ds.data == ds.target).sum() - >>> correct_classification > 90 + >>> correct_classification > 85 True >>> ds.dims == {"sample": 150} True