Made Sample and Checkpoint Mixins transform samplesets

parent 2dc6e372
Pipeline #38186 passed with stage
in 3 minutes and 33 seconds
# vim: set fileencoding=utf-8 :
from .sample import Sample, DelayedSample
from .sample import Sample, DelayedSample, SampleSet
import os
import types
import cloudpickle
......@@ -190,9 +190,20 @@ class SampleMixin:
"""
def transform(self, samples):
    """Transform a sequence of samples or a sequence of sample sets.

    The type of the first element selects the code path:

    * ``Sample`` / ``DelayedSample``: the ``data`` attribute of every
      element is handed to the parent class' ``transform`` and each
      returned feature is wrapped in a new ``Sample`` whose ``parent`` is
      the corresponding input sample.
    * ``SampleSet``: the ``samples`` of every set are transformed
      recursively and wrapped in a new ``SampleSet`` whose ``parent`` is
      the input set.

    Parameters
    ----------
    samples
        Sequence of ``Sample``/``DelayedSample`` objects, or sequence of
        ``SampleSet`` objects.

    Returns
    -------
    list
        New ``Sample`` objects or new ``SampleSet`` objects, one per input
        element. An empty input yields an empty list.

    Raises
    ------
    ValueError
        If the first element is not a supported sample type.
    """
    # An empty input has no first element to inspect and nothing to
    # transform; without this guard ``samples[0]`` raises IndexError.
    if len(samples) == 0:
        return []

    first = samples[0]

    # Transform either samples or samplesets
    if isinstance(first, (Sample, DelayedSample)):
        features = super().transform([s.data for s in samples])
        return [Sample(data, parent=s) for data, s in zip(features, samples)]

    if isinstance(first, SampleSet):
        return [
            SampleSet(self.transform(sset.samples), parent=sset)
            for sset in samples
        ]

    # Report the type of the offending element, not of the container.
    raise ValueError("Type for sample not supported %s" % type(first))
def fit(self, samples, y=None):
    """Fit the parent estimator on the ``data`` of each sample.

    ``y`` is accepted for signature compatibility but is not forwarded to
    the parent ``fit``.

    Returns whatever the parent class' ``fit`` returns.
    """
    data = []
    for sample in samples:
        data.append(sample.data)
    return super().fit(data)
......@@ -210,6 +221,7 @@ class CheckpointMixin:
def transform_one_sample(self, sample):
# Check if the sample is already processed.
path = self.make_path(sample)
if path is None or not os.path.isfile(path):
......@@ -221,8 +233,22 @@ class CheckpointMixin:
return new_sample
def transform_one_sample_set(self, sample_set):
    """Transform every sample of one sample set.

    Each element of ``sample_set.samples`` goes through
    ``transform_one_sample``; the results are returned in a new
    ``SampleSet`` whose ``parent`` is the input set.
    """
    processed = []
    for sample in sample_set.samples:
        processed.append(self.transform_one_sample(sample))
    return SampleSet(processed, parent=sample_set)
def transform(self, samples):
    """Transform a list of samples or a list of sample sets.

    The type of the first element selects the code path:
    ``Sample``/``DelayedSample`` elements are processed one by one with
    ``transform_one_sample``; ``SampleSet`` elements are processed one by
    one with ``transform_one_sample_set``.

    Parameters
    ----------
    samples : list
        List of ``Sample``/``DelayedSample`` objects, or list of
        ``SampleSet`` objects.

    Returns
    -------
    list
        One transformed element per input element. An empty list yields
        an empty list.

    Raises
    ------
    ValueError
        If ``samples`` is not a list, or if its first element is not a
        supported sample type.
    """
    if not isinstance(samples, list):
        raise ValueError("It's expected a list, not %s" % type(samples))

    # Nothing to do for an empty list; without this guard ``samples[0]``
    # raises IndexError.
    if not samples:
        return []

    first = samples[0]
    if isinstance(first, (Sample, DelayedSample)):
        return [self.transform_one_sample(s) for s in samples]
    if isinstance(first, SampleSet):
        return [self.transform_one_sample_set(s) for s in samples]

    raise ValueError("Type not allowed %s" % type(first))
def fit(self, samples, y=None):
if self.model_path is not None and os.path.isfile(self.model_path):
......@@ -246,8 +272,16 @@ class CheckpointMixin:
return key
def save(self, sample):
    """Save the data of a sample, or of every sample in a sample set.

    For a ``Sample``/``DelayedSample`` the target path comes from
    ``make_path`` and ``sample.data`` is written with
    ``bob.io.base.save`` (creating directories as needed). For a
    ``SampleSet`` every element of ``sample.samples`` is saved in turn.

    Parameters
    ----------
    sample
        A ``Sample``, ``DelayedSample`` or ``SampleSet``.

    Returns
    -------
    The return value of ``bob.io.base.save`` for a single sample;
    ``None`` for a sample set.

    Raises
    ------
    ValueError
        If ``sample`` is not one of the supported types.
    """
    if isinstance(sample, (Sample, DelayedSample)):
        path = self.make_path(sample)
        return bob.io.base.save(sample.data, path, create_directories=True)

    if isinstance(sample, SampleSet):
        # Save every member; returning from inside the loop would stop
        # after the first sample and silently skip the rest.
        for s in sample.samples:
            self.save(s)
        return None

    raise ValueError("Type for sample not supported %s" % type(sample))
def load(self, path):
key = self.recover_key_from_path(path)
......@@ -369,7 +403,7 @@ class DaskEstimatorMixin:
return self
def transform(self, X):
def _transf(X_line, dask_state):
def _transf(X_line, dask_state):
return super(DaskEstimatorMixin, dask_state).transform(X_line)
map_partitions = X.map_partitions(_transf, self._dask_state)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment