Skip to content
Snippets Groups Projects
Commit 6a01d66e authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Merge branch 'revert-bd57d13a' into 'master'

Revert "For some reason, the class information is not passed in the sample wrapper"

See merge request !37
parents bd57d13a 0dfd922a
Branches
Tags
1 merge request!37Revert "For some reason, the class information is not passed in the sample wrapper"
Pipeline #43898 passed
......@@ -153,6 +153,12 @@ class SampleWrapper(BaseWrapper, TransformerMixin):
return self._samples_transform(samples, "score")
def fit(self, samples, y=None):
if y is not None:
raise TypeError(
"We don't accept `y` in fit arguments because "
"`y` should be part of the sample. To pass `y` "
"to the wrapped estimator, use `fit_extra_arguments`."
)
if is_estimator_stateless(self.estimator):
return self
......@@ -163,7 +169,7 @@ class SampleWrapper(BaseWrapper, TransformerMixin):
X = SampleBatch(samples)
self.estimator = self.estimator.fit(X, y=y, **kwargs)
self.estimator = self.estimator.fit(X, **kwargs)
copy_learned_attributes(self.estimator, self)
return self
......
......@@ -40,7 +40,7 @@ transformer.
... print(f"Transforming {len(X)} samples ...")
... return np.array(X) + np.array(sample_specific_offsets)
...
... def fit(self, X, y=None):
... def fit(self, X):
... print("Fit was called!")
... return self
......
......@@ -445,7 +445,7 @@ provide dask-ml estimators, set ``input_dask_array`` as ``True``.
>>> ds = pipeline.fit(dataset).predict(dataset)
>>> ds = ds.compute()
>>> correct_classification = np.array(ds.data == ds.target).sum()
>>> correct_classification > 90
>>> correct_classification > 85
True
>>> ds.dims == {"sample": 150}
True
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment