Commit a1148796 authored by Tiago de Freitas Pereira

Merge branch 'sampleset-delayed-sample' into 'master'

Make a sampleset work transparently with list of DelayedSamples

See merge request !31
parents 7bf65c2b 6a9d848e
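For context, here is a minimal usage sketch of what this merge enables (assuming bob.pipelines is installed; the pickle file name is illustrative): a SampleSet can now wrap a single DelayedSample whose data is a list of samples, and the set resolves that list transparently on first access.

import functools
import pickle

from bob.pipelines import DelayedSample, SampleSet


def _load(path):
    # deserializes a pickled list of Samples from disk
    with open(path, "rb") as f:
        return pickle.load(f)


# "samples.pkl" is a hypothetical file holding a pickled list of Samples
sampleset = SampleSet(DelayedSample(functools.partial(_load, "samples.pkl")), key=1)
print(len(sampleset))  # first access triggers the delayed load
print(sampleset[0])    # afterwards the set behaves like a plain list of Samples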
from . import utils
-from .sample import Sample, DelayedSample, SampleSet
+from .sample import Sample, DelayedSample, SampleSet, sample_to_hdf5, hdf5_to_sample
from .wrappers import (
    BaseWrapper,
    DelayedSamplesCall,
@@ -262,6 +262,7 @@ class SGEMultipleQueuesCluster(JobQueueCluster):
        # removal before we remove it.
        # Here the goal is to wait 2 minutes before scaling down since
        # it is very expensive to get jobs on the SGE grid
+       self.adapt(minimum=min_jobs, maximum=max_jobs, wait_count=60, interval=1000)

    def _get_worker_spec_options(self, job_spec):
@@ -277,6 +278,9 @@ class SGEMultipleQueuesCluster(JobQueueCluster):
"io_big=TRUE," if "io_big" in job_spec and job_spec["io_big"] else ""
)
memory = _get_key_from_spec(job_spec, "memory")[:-1]
new_resource_spec += (f"mem_free={memory},")
queue = _get_key_from_spec(job_spec, "queue")
if queue != "all.q":
new_resource_spec += f"{queue}=TRUE"
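For concreteness, a standalone sketch of the resource string assembled above (the job_spec values are illustrative, and _get_key_from_spec is replaced by plain dict access):

job_spec = {"io_big": True, "memory": "8GB", "queue": "q_1week"}
new_resource_spec = "io_big=TRUE,"                           # io_big flag
new_resource_spec += f"mem_free={job_spec['memory'][:-1]},"  # "8GB" -> "8G" for SGE
new_resource_spec += f"{job_spec['queue']}=TRUE"             # queue is not "all.q"
print(new_resource_spec)  # io_big=TRUE,mem_free=8G,q_1week=TRUE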
@@ -285,7 +289,7 @@ class SGEMultipleQueuesCluster(JobQueueCluster):
        return {
            "queue": queue,
-           "memory": _get_key_from_spec(job_spec, "memory"),
+           "memory": "0",  # memory is now requested via mem_free in the resource spec above
            "cores": 1,
            "processes": 1,
            "log_directory": self.log_directory,
@@ -440,7 +444,13 @@ class SchedulerResourceRestriction(Scheduler):
"""
def __init__(self, *args, **kwargs):
super(SchedulerResourceRestriction, self).__init__(*args, **kwargs)
super(SchedulerResourceRestriction, self).__init__(
idle_timeout=3600,
allowed_failures=500,
synchronize_worker_interval="240s",
*args,
**kwargs,
)
self.handlers[
"get_no_worker_tasks_resource_restrictions"
] = self.get_no_worker_tasks_resource_restrictions
@@ -3,6 +3,8 @@
from collections.abc import MutableSequence, Sequence
from .utils import vstack_features
import numpy as np
+import os
+import h5py


def _copy_attributes(s, d):
@@ -24,6 +26,31 @@ class _ReprMixin:
+ ")"
)
def __eq__(self, other):
sorted_self = {
k: v for k, v in sorted(self.__dict__.items(), key=lambda item: item[0])
}
sorted_other = {
k: v for k, v in sorted(other.__dict__.items(), key=lambda item: item[0])
}
for s, o in zip(sorted_self, sorted_other):
# Checking keys
if s != o:
return False
# Checking values
if isinstance(sorted_self[s], np.ndarray) and isinstance(
sorted_self[o], np.ndarray
):
if not np.allclose(sorted_self[s], sorted_other[o]):
return False
else:
if sorted_self[s] != sorted_other[o]:
return False
return True
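For illustration, a small sketch of the comparison semantics this __eq__ provides (assuming Sample is importable from bob.pipelines): attributes are compared by sorted name, and array-valued attributes with np.allclose.

import numpy as np
from bob.pipelines import Sample

a = Sample(np.array([1.0, 2.0]), key="0")
b = Sample(np.array([1.0, 2.0]), key="0")
assert a == b                                      # arrays compared with np.allclose
assert a != Sample(np.array([9.0, 2.0]), key="0")  # same keys, different data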
class Sample(_ReprMixin):
    """Representation of a sample. A Sample is a simple container that wraps a
@@ -99,19 +126,28 @@ class SampleSet(MutableSequence, _ReprMixin):
        _copy_attributes(self, parent.__dict__)
        _copy_attributes(self, kwargs)

    def _load(self):
        # A SampleSet may wrap a single DelayedSample whose data is the
        # actual list of samples; resolve it lazily on first access.
        if isinstance(self.samples, DelayedSample):
            self.samples = self.samples.data

    def __len__(self):
        self._load()
        return len(self.samples)

    def __getitem__(self, item):
        self._load()
        return self.samples.__getitem__(item)

    def __setitem__(self, key, item):
        self._load()
        return self.samples.__setitem__(key, item)

    def __delitem__(self, item):
        self._load()
        return self.samples.__delitem__(item)

    def insert(self, index, item):
        self._load()
        # if not item in self.samples:
        self.samples.insert(index, item)
@@ -136,5 +172,58 @@ class SampleBatch(Sequence, _ReprMixin):
        def _reader(s):
            # adding one more dimension to data so they get stacked sample-wise
            return s.data[None, ...]

        arr = vstack_features(_reader, self.samples, dtype=dtype)
        return np.asarray(arr, dtype, *args, **kwargs)


def sample_to_hdf5(sample, hdf5):
    """
    Saves the content of a sample to an HDF5 file.

    Parameters
    ----------

    sample: :any:`Sample` or :any:`DelayedSample` or :any:`list`
        Sample to be saved

    hdf5: `h5py.File`
        Pointer to an HDF5 file for writing
    """
    if isinstance(sample, list):
        # A list is stored as one numbered group per sample ("0", "1", ...)
        for i, s in enumerate(sample):
            group = hdf5.create_group(str(i))
            sample_to_hdf5(s, group)
    else:
        # A single sample is stored as one dataset per attribute
        for s in sample.__dict__:
            hdf5[s] = sample.__dict__[s]


def hdf5_to_sample(hdf5):
    """
    Reads the content of an HDF5 file and returns a :any:`Sample`
    (or a list of :any:`Sample` if the file contains groups).

    Parameters
    ----------

    hdf5: `h5py.File`
        Pointer to an HDF5 file for reading
    """
    # Checking if it has groups
    has_groups = any(isinstance(hdf5[k], h5py.Group) for k in hdf5.keys())

    if has_groups:
        # If it has groups, returns a list of Samples
        samples = []
        for k in hdf5.keys():
            group = hdf5[k]
            samples.append(hdf5_to_sample(group))
        return samples
    else:
        # If it has no groups, returns a single Sample
        sample = Sample(None)
        for k in hdf5.keys():
            # hdf5[k][()] reads the dataset contents (h5py 3.x removed `.value`)
            sample.__dict__[k] = hdf5[k][()]
        return sample
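A minimal round-trip sketch of these two helpers (the file name is illustrative; note that with h5py >= 3 string attributes are read back as bytes):

import h5py
import numpy as np
from bob.pipelines import Sample, hdf5_to_sample, sample_to_hdf5

samples = [Sample(np.arange(3), key=str(i)) for i in range(2)]
with h5py.File("samples.hdf5", "w") as f:
    sample_to_hdf5(samples, f)    # stores groups "0" and "1"
with h5py.File("samples.hdf5", "r") as f:
    restored = hdf5_to_sample(f)  # returns a list of Samples
assert np.allclose(restored[0].data, samples[0].data)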
-import bob.pipelines as mario
-import numpy
+from bob.pipelines import (
+    Sample,
+    DelayedSample,
+    SampleSet,
+    sample_to_hdf5,
+    hdf5_to_sample,
+)
import bob.io.base
+import numpy as np
import copy
import pickle
import tempfile
import functools
import os
import h5py
def test_sampleset_collection():
    n_samples = 10
-   X = numpy.ones(shape=(n_samples, 2), dtype=int)
-   sampleset = mario.SampleSet(
-       [mario.Sample(data, key=str(i)) for i, data in enumerate(X)], key="1"
+   X = np.ones(shape=(n_samples, 2), dtype=int)
+   sampleset = SampleSet(
+       [Sample(data, key=str(i)) for i, data in enumerate(X)], key="1"
    )

    assert len(sampleset) == n_samples

    # Testing insert
-   sample = mario.Sample(X, key=100)
+   sample = Sample(X, key=100)
    sampleset.insert(1, sample)
    assert len(sampleset) == n_samples + 1
@@ -27,4 +39,49 @@ def test_sampleset_collection():
    # Testing iterator
    for i in sampleset:
-       assert isinstance(i, mario.Sample)
+       assert isinstance(i, Sample)

    def _load(path):
        with open(path, "rb") as f:  # close the file handle after reading
            return pickle.loads(f.read())

    # Testing delayed sample in the sampleset
    with tempfile.TemporaryDirectory() as dir_name:
        samples = [Sample(data, key=str(i)) for i, data in enumerate(X)]
        filename = os.path.join(dir_name, "samples.pkl")
        with open(filename, "wb") as f:
            f.write(pickle.dumps(samples))

        sampleset = SampleSet(DelayedSample(functools.partial(_load, filename)), key=1)
        assert len(sampleset) == n_samples
def test_sample_hdf5():
    n_samples = 10
    X = np.ones(shape=(n_samples, 2), dtype=int)
    samples = [Sample(data, key=str(i), subject="Subject") for i, data in enumerate(X)]

    with tempfile.TemporaryDirectory() as dir_name:
        # Single sample
        filename = os.path.join(dir_name, "sample.hdf5")

        with h5py.File(filename, "w", driver="core") as hdf5:
            sample_to_hdf5(samples[0], hdf5)

        with h5py.File(filename, "r") as hdf5:
            sample = hdf5_to_sample(hdf5)

        assert sample == samples[0]

        # List of samples
        filename = os.path.join(dir_name, "samples.hdf5")
        with h5py.File(filename, "w", driver="core") as hdf5:
            sample_to_hdf5(samples, hdf5)

        with h5py.File(filename, "r") as hdf5:
            samples_deserialized = hdf5_to_sample(hdf5)

        compare = [a == b for a, b in zip(samples_deserialized, samples)]
        assert np.sum(compare) == n_samples  # every pair must compare equal
@@ -217,11 +217,12 @@ class CheckpointWrapper(BaseWrapper, TransformerMixin):
        for s, p, should_compute in zip(samples, paths, should_compute_list):
            if should_compute:
                feat = computed_features[com_feat_index]
-               features.append(feat)
                com_feat_index += 1
                # save the computed feature
                if p is not None:
                    self.save(feat)
+                   feat = self.load(s, p)  # reload so downstream sees the checkpointed feature
+               features.append(feat)
            else:
                features.append(self.load(s, p))

        return features
@@ -398,16 +399,20 @@ class ToDaskBag(TransformerMixin, BaseEstimator):
        Number of partitions used in :any:`dask.bag.from_sequence`
    """

-   def __init__(self, npartitions=None, **kwargs):
+   def __init__(self, npartitions=None, partition_size=None, **kwargs):
        super().__init__(**kwargs)
        self.npartitions = npartitions
+       self.partition_size = partition_size

    def fit(self, X, y=None):
        return self

    def transform(self, X):
        logger.debug(f"{_frmt(self)}.transform")
+       if self.partition_size is None:
+           return dask.bag.from_sequence(X, npartitions=self.npartitions)
+       else:
+           return dask.bag.from_sequence(X, partition_size=self.partition_size)

    def _more_tags(self):
        return {"stateless": True, "requires_fit": False}
@@ -36,6 +36,7 @@ requirements:
    - dask-jobqueue
    - distributed
    - scikit-learn
+   - h5py
test:
  imports:
@@ -6,3 +6,4 @@ scikit-learn
dask
distributed
dask-jobqueue
+h5py
\ No newline at end of file