Commit 9c6a0be1 authored by Tiago de Freitas Pereira

Merge branch 'master' into 'xarray'

# Conflicts:
#   conda/meta.yaml
#   requirements.txt
parents 3d4331ea a1148796
Merge request !30: Add DatasetPipeline to work on xarray datasets
Pipeline #40052 failed
 from . import utils
-from .sample import Sample, DelayedSample, SampleSet
+from .sample import Sample, DelayedSample, SampleSet, sample_to_hdf5, hdf5_to_sample
 from .wrappers import (
     BaseWrapper,
     DelayedSamplesCall,
...
@@ -262,6 +262,7 @@ class SGEMultipleQueuesCluster(JobQueueCluster):
         # removal before we remove it.
         # Here the goal is to wait 2 minutes before scaling down since
         # it is very expensive to get jobs on the SGE grid
         self.adapt(minimum=min_jobs, maximum=max_jobs, wait_count=60, interval=1000)

     def _get_worker_spec_options(self, job_spec):
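For context on the `adapt()` call in this hunk: under dask's adaptive scaling, a worker is only retired after `wait_count` consecutive idle checks taken every `interval` milliseconds, so the scale-down delay is roughly their product. A minimal sketch with the values from the diff (the arithmetic is the only addition here):

```python
# Scale-down delay implied by adapt(..., wait_count=60, interval=1000),
# assuming wait_count consecutive idle checks spaced interval ms apart.
wait_count = 60
interval_ms = 1000
print(wait_count * interval_ms / 1000)  # 60.0 seconds of sustained idleness
```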
@@ -277,6 +278,9 @@ class SGEMultipleQueuesCluster(JobQueueCluster):
             "io_big=TRUE," if "io_big" in job_spec and job_spec["io_big"] else ""
         )

+        memory = _get_key_from_spec(job_spec, "memory")[:-1]
+        new_resource_spec += (f"mem_free={memory},")
+
         queue = _get_key_from_spec(job_spec, "queue")
         if queue != "all.q":
             new_resource_spec += f"{queue}=TRUE"
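To see what the appended `mem_free` fragment produces, here is a hypothetical walk-through of the resource-spec string being built; the `job_spec` values are invented and `_get_key_from_spec` is approximated with a plain dict lookup:

```python
# Hypothetical SGE job spec (values illustrative).
job_spec = {"queue": "q_1day", "memory": "8GB", "io_big": True}

new_resource_spec = "io_big=TRUE," if job_spec.get("io_big") else ""
memory = job_spec["memory"][:-1]  # strip the trailing "B": "8GB" -> "8G"
new_resource_spec += f"mem_free={memory},"
if job_spec["queue"] != "all.q":
    new_resource_spec += f"{job_spec['queue']}=TRUE"

print(new_resource_spec)  # io_big=TRUE,mem_free=8G,q_1day=TRUE
```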
@@ -285,7 +289,7 @@ class SGEMultipleQueuesCluster(JobQueueCluster):
         return {
             "queue": queue,
-            "memory": _get_key_from_spec(job_spec, "memory"),
+            "memory": "0",
             "cores": 1,
             "processes": 1,
             "log_directory": self.log_directory,
@@ -440,7 +444,13 @@ class SchedulerResourceRestriction(Scheduler):
     """

     def __init__(self, *args, **kwargs):
-        super(SchedulerResourceRestriction, self).__init__(*args, **kwargs)
+        super(SchedulerResourceRestriction, self).__init__(
+            idle_timeout=3600,
+            allowed_failures=500,
+            synchronize_worker_interval="240s",
+            *args,
+            **kwargs,
+        )
         self.handlers[
             "get_no_worker_tasks_resource_restrictions"
         ] = self.get_no_worker_tasks_resource_restrictions
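One consequence of placing hard-coded keywords before `*args`/`**kwargs` in the call above is that callers cannot override them: passing, say, `idle_timeout` again raises a `TypeError`. A self-contained illustration with stand-in classes (not the real dask `Scheduler`):

```python
class Base:
    def __init__(self, idle_timeout=0, **kwargs):
        self.idle_timeout = idle_timeout

class Child(Base):
    def __init__(self, *args, **kwargs):
        # mirrors the pattern in SchedulerResourceRestriction.__init__
        super().__init__(idle_timeout=3600, *args, **kwargs)

Child()  # fine: idle_timeout is 3600
try:
    Child(idle_timeout=10)
except TypeError as e:
    print(e)  # ... got multiple values for keyword argument 'idle_timeout'
```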
...
@@ -3,6 +3,8 @@
 from collections.abc import MutableSequence, Sequence
 from .utils import vstack_features
 import numpy as np
+import os
+import h5py

 SAMPLE_DATA_ATTRS = ("data", "load", "samples", "_data")
@@ -26,6 +28,31 @@ class _ReprMixin:
             + ")"
         )

+    def __eq__(self, other):
+        sorted_self = {
+            k: v for k, v in sorted(self.__dict__.items(), key=lambda item: item[0])
+        }
+        sorted_other = {
+            k: v for k, v in sorted(other.__dict__.items(), key=lambda item: item[0])
+        }
+
+        for s, o in zip(sorted_self, sorted_other):
+            # Checking keys
+            if s != o:
+                return False
+
+            # Checking values
+            if isinstance(sorted_self[s], np.ndarray) and isinstance(
+                sorted_other[o], np.ndarray
+            ):
+                if not np.allclose(sorted_self[s], sorted_other[o]):
+                    return False
+            else:
+                if sorted_self[s] != sorted_other[o]:
+                    return False
+
+        return True
+

 class Sample(_ReprMixin):
     """Representation of sample. A Sample is a simple container that wraps a
@@ -101,19 +128,28 @@ class SampleSet(MutableSequence, _ReprMixin):
             _copy_attributes(self, parent.__dict__)
         _copy_attributes(self, kwargs)

+    def _load(self):
+        if isinstance(self.samples, DelayedSample):
+            self.samples = self.samples.data
+
     def __len__(self):
+        self._load()
         return len(self.samples)

     def __getitem__(self, item):
+        self._load()
         return self.samples.__getitem__(item)

     def __setitem__(self, key, item):
+        self._load()
         return self.samples.__setitem__(key, item)

     def __delitem__(self, item):
+        self._load()
         return self.samples.__delitem__(item)

     def insert(self, index, item):
+        self._load()
         # if not item in self.samples:
         self.samples.insert(index, item)
@@ -138,5 +174,58 @@ class SampleBatch(Sequence, _ReprMixin):
         def _reader(s):
             # adding one more dimension to data so they get stacked sample-wise
             return s.data[None, ...]

         arr = vstack_features(_reader, self.samples, dtype=dtype)
         return np.asarray(arr, dtype, *args, **kwargs)
+
+
+def sample_to_hdf5(sample, hdf5):
+    """
+    Saves the content of a sample to an HDF5 file.
+
+    Parameters
+    ----------
+
+    sample: :any:`Sample` or :any:`DelayedSample` or :any:`list`
+        Sample to be saved
+
+    hdf5: `h5py.File`
+        Pointer to an HDF5 file for writing
+    """
+    if isinstance(sample, list):
+        for i, s in enumerate(sample):
+            group = hdf5.create_group(str(i))
+            sample_to_hdf5(s, group)
+    else:
+        for s in sample.__dict__:
+            hdf5[s] = sample.__dict__[s]
+
+
+def hdf5_to_sample(hdf5):
+    """
+    Reads the content of an HDF5 file and returns a :any:`Sample`.

+    Parameters
+    ----------
+
+    hdf5: `h5py.File`
+        Pointer to an HDF5 file for reading
+    """
+    # Checking if the file has groups
+    has_groups = np.sum([isinstance(hdf5[k], h5py.Group) for k in hdf5.keys()]) > 0
+
+    if has_groups:
+        # If it has groups, return a list of Samples
+        samples = []
+        for k in hdf5.keys():
+            group = hdf5[k]
+            samples.append(hdf5_to_sample(group))
+        return samples
+    else:
+        # If it has no groups, return a single Sample
+        sample = Sample(None)
+        for k in hdf5.keys():
+            sample.__dict__[k] = hdf5[k].value
+        return sample
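As a reading aid, the layout these two helpers imply: each sample attribute becomes an HDF5 dataset, and a list of samples becomes numbered groups. A minimal round-trip sketch (file name invented; note that `hdf5_to_sample` relies on the h5py 2.x `Dataset.value` accessor):

```python
import h5py
import numpy as np
from bob.pipelines import Sample, sample_to_hdf5, hdf5_to_sample

samples = [Sample(np.ones(2), key=str(i)) for i in range(2)]
with h5py.File("samples.hdf5", "w") as f:
    sample_to_hdf5(samples, f)
# resulting structure:
#   /0/data, /0/key  <- attributes of samples[0]
#   /1/data, /1/key  <- attributes of samples[1]
with h5py.File("samples.hdf5", "r") as f:
    restored = hdf5_to_sample(f)  # a list, since the file has groups
```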
-import bob.pipelines as mario
-import numpy
+from bob.pipelines import (
+    Sample,
+    DelayedSample,
+    SampleSet,
+    sample_to_hdf5,
+    hdf5_to_sample,
+)
+import bob.io.base
+import numpy as np
 import copy
+import pickle
+import tempfile
+import functools
+import os
+import h5py
 def test_sampleset_collection():
     n_samples = 10
-    X = numpy.ones(shape=(n_samples, 2), dtype=int)
-    sampleset = mario.SampleSet(
-        [mario.Sample(data, key=str(i)) for i, data in enumerate(X)], key="1"
+    X = np.ones(shape=(n_samples, 2), dtype=int)
+    sampleset = SampleSet(
+        [Sample(data, key=str(i)) for i, data in enumerate(X)], key="1"
     )
     assert len(sampleset) == n_samples

     # Testing insert
-    sample = mario.Sample(X, key=100)
+    sample = Sample(X, key=100)
     sampleset.insert(1, sample)
     assert len(sampleset) == n_samples + 1
@@ -27,4 +39,49 @@ def test_sampleset_collection():
     # Testing iterator
     for i in sampleset:
-        assert isinstance(i, mario.Sample)
+        assert isinstance(i, Sample)
+
+    def _load(path):
+        return pickle.loads(open(path, "rb").read())
+
+    # Testing delayed sample in the sampleset
+    with tempfile.TemporaryDirectory() as dir_name:
+        samples = [Sample(data, key=str(i)) for i, data in enumerate(X)]
+        filename = os.path.join(dir_name, "samples.pkl")
+        with open(filename, "wb") as f:
+            f.write(pickle.dumps(samples))
+
+        sampleset = SampleSet(DelayedSample(functools.partial(_load, filename)), key=1)
+        assert len(sampleset) == n_samples
+
+
+def test_sample_hdf5():
+    n_samples = 10
+    X = np.ones(shape=(n_samples, 2), dtype=int)
+    samples = [
+        Sample(data, key=str(i), subject="Subject") for i, data in enumerate(X)
+    ]
+
+    with tempfile.TemporaryDirectory() as dir_name:
+        # Single sample
+        filename = os.path.join(dir_name, "sample.hdf5")
+
+        with h5py.File(filename, "w", driver="core") as hdf5:
+            sample_to_hdf5(samples[0], hdf5)
+
+        with h5py.File(filename, "r") as hdf5:
+            sample = hdf5_to_sample(hdf5)
+
+        assert sample == samples[0]
+
+        # List of samples
+        filename = os.path.join(dir_name, "samples.hdf5")
+
+        with h5py.File(filename, "w", driver="core") as hdf5:
+            sample_to_hdf5(samples, hdf5)
+
+        with h5py.File(filename, "r") as hdf5:
+            samples_deserialized = hdf5_to_sample(hdf5)
+
+        compare = [a == b for a, b in zip(samples_deserialized, samples)]
+        assert np.sum(compare) == 10
@@ -217,11 +217,12 @@ class CheckpointWrapper(BaseWrapper, TransformerMixin):
         for s, p, should_compute in zip(samples, paths, should_compute_list):
             if should_compute:
                 feat = computed_features[com_feat_index]
-                features.append(feat)
                 com_feat_index += 1
                 # save the computed feature
                 if p is not None:
                     self.save(feat)
+                    feat = self.load(s, p)
+                features.append(feat)
             else:
                 features.append(self.load(s, p))
         return features
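The reordering in this hunk makes the wrapper hand downstream the feature as re-read from disk whenever a checkpoint path exists, so later steps consume exactly what a resumed run would load. A toy version of the pattern with stand-in `save`/`load` helpers (names hypothetical, pickle in place of the wrapper's real serialization):

```python
import os
import pickle
import tempfile

def save(feat, path):
    with open(path, "wb") as f:
        pickle.dump(feat, f)

def load(path):
    with open(path, "rb") as f:
        return pickle.load(f)

with tempfile.TemporaryDirectory() as d:
    path = os.path.join(d, "feat.pkl")
    feat = [1.0, 2.0, 3.0]  # freshly computed feature
    save(feat, path)        # checkpoint it
    feat = load(path)       # downstream gets the reloaded copy
    assert feat == [1.0, 2.0, 3.0]
```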
@@ -398,16 +399,20 @@ class ToDaskBag(TransformerMixin, BaseEstimator):
         Number of partitions used in :any:`dask.bag.from_sequence`
     """

-    def __init__(self, npartitions=None, **kwargs):
+    def __init__(self, npartitions=None, partition_size=None, **kwargs):
         super().__init__(**kwargs)
         self.npartitions = npartitions
+        self.partition_size = partition_size

     def fit(self, X, y=None):
         return self

     def transform(self, X):
         logger.debug(f"{_frmt(self)}.transform")
-        return dask.bag.from_sequence(X, npartitions=self.npartitions)
+        if self.partition_size is None:
+            return dask.bag.from_sequence(X, npartitions=self.npartitions)
+        else:
+            return dask.bag.from_sequence(X, partition_size=self.partition_size)

     def _more_tags(self):
         return {"stateless": True, "requires_fit": False}
...