Skip to content
Snippets Groups Projects
Commit 38860b21 authored by Tiago de Freitas Pereira
Browse files

Allowed to set the argument dask.bag.Bag.partition_size

parent 6d4994a8
No related branches found
No related tags found
1 merge request: !31 "Make a sampleset work transparently with list of DelayedSamples"
@@ -216,12 +216,13 @@ class CheckpointWrapper(BaseWrapper, TransformerMixin):
         features, com_feat_index = [], 0
         for s, p, should_compute in zip(samples, paths, should_compute_list):
             if should_compute:
                 feat = computed_features[com_feat_index]
-                features.append(feat)
                 com_feat_index += 1
                 # save the computed feature
                 if p is not None:
                     self.save(feat)
+                    feat = self.load(s, p)
+                features.append(feat)
             else:
                 features.append(self.load(s, p))
         return features
@@ -398,16 +399,20 @@ class ToDaskBag(TransformerMixin, BaseEstimator):
         Number of partitions used in :any:`dask.bag.from_sequence`
     """
 
-    def __init__(self, npartitions=None, **kwargs):
+    def __init__(self, npartitions=None, partition_size=None, **kwargs):
         super().__init__(**kwargs)
         self.npartitions = npartitions
+        self.partition_size = partition_size
 
     def fit(self, X, y=None):
         return self
 
     def transform(self, X):
         logger.debug(f"{_frmt(self)}.transform")
-        return dask.bag.from_sequence(X, npartitions=self.npartitions)
+        if self.partition_size is None:
+            return dask.bag.from_sequence(X, npartitions=self.npartitions)
+        else:
+            return dask.bag.from_sequence(X, partition_size=self.partition_size)
 
     def _more_tags(self):
         return {"stateless": True, "requires_fit": False}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment