Commit 2f46f325 authored by Yannick DAYER
Browse files

Transformer output handled as list

parent 4e39dcc8
......@@ -6,7 +6,6 @@ from functools import partial
import cloudpickle
import dask.bag
import numpy
from dask import delayed
from sklearn.base import BaseEstimator
......@@ -90,12 +89,14 @@ class DelayedSamplesCall:
if len(valid_samples) > 0:
X = SampleBatch(valid_samples, sample_attribute=self.sample_attribute)
self.output = self.func(X)
_check_n_input_output(valid_samples, self.output, self.func_name)
if self.output is None:
self.output = [None] * len(valid_samples)
# Rebuild the full batch of samples (including previously failed)
for i in invalid_ids:
self.output = numpy.insert(self.output, i, None, axis=0)
_check_n_input_output(self.samples, self.output, self.func_name)
# Rebuild the full batch of samples (include the previously failed)
if len(invalid_ids) > 0:
self.output = list(self.output)
for i in invalid_ids:
self.output.insert(i, None)
return self.output[index]
......@@ -322,11 +323,10 @@ class CheckpointWrapper(BaseWrapper, TransformerMixin):
if should_compute:
feat = computed_features[com_feat_index]
com_feat_index += 1
# save the computed feature when valid (not NaN)
# save the computed feature when valid (not None)
if (
p is not None
and getattr(feat, self.sample_attribute) is not None
and not numpy.isnan(getattr(feat, self.sample_attribute)).any()
):
self.save(feat)
feat = self.load(s, p)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment